util/winutil: consolidate interface specific registry keys

Code movement to allow reuse in a follow up PR.

Updates #1659

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/5365/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 23f37b05a3
commit 545639ee44

@ -20,12 +20,10 @@ import (
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
"tailscale.com/util/winutil"
) )
const ( const (
ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion` versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion`
) )
@ -59,24 +57,15 @@ func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator,
return ret, nil return ret, nil
} }
// keyOpenTimeout is how long we wait for a registry key to func (m windowsManager) openInterfaceKey(pfx winutil.RegistryPathPrefix) (registry.Key, error) {
// appear. For some reason, registry keys tied to ephemeral interfaces path := pfx.WithSuffix(m.guid)
// can take a long while to appear after interface creation, and we key, err := winutil.OpenKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE)
// can end up racing with that.
const keyOpenTimeout = 20 * time.Second
func (m windowsManager) openKey(path string) (registry.Key, error) {
key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout)
if err != nil { if err != nil {
return 0, fmt.Errorf("opening %s: %w", path, err) return 0, fmt.Errorf("opening %s: %w", path, err)
} }
return key, nil return key, nil
} }
func (m windowsManager) ifPath(basePath string) string {
return fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
}
func delValue(key registry.Key, name string) error { func delValue(key registry.Key, name string) error {
if err := key.DeleteValue(name); err != nil && err != registry.ErrNotExist { if err := key.DeleteValue(name); err != nil && err != registry.ErrNotExist {
return err return err
@ -134,7 +123,7 @@ func (m windowsManager) setPrimaryDNS(resolvers []netip.Addr, domains []dnsname.
domStrs = append(domStrs, dom.WithoutTrailingDot()) domStrs = append(domStrs, dom.WithoutTrailingDot())
} }
key4, err := m.openKey(m.ifPath(ipv4RegBase)) key4, err := m.openInterfaceKey(winutil.IPv4TCPIPInterfacePrefix)
if err != nil { if err != nil {
return err return err
} }
@ -156,7 +145,7 @@ func (m windowsManager) setPrimaryDNS(resolvers []netip.Addr, domains []dnsname.
return err return err
} }
key6, err := m.openKey(m.ifPath(ipv6RegBase)) key6, err := m.openInterfaceKey(winutil.IPv6TCPIPInterfacePrefix)
if err != nil { if err != nil {
return err return err
} }
@ -308,25 +297,26 @@ func (m windowsManager) Close() error {
// Windows DHCP client from sending dynamic DNS updates for our interface to // Windows DHCP client from sending dynamic DNS updates for our interface to
// AD domain controllers. // AD domain controllers.
func (m windowsManager) disableDynamicUpdates() error { func (m windowsManager) disableDynamicUpdates() error {
setRegValue := func(regBase string) error { if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
key, err := m.openKey(m.ifPath(regBase)) return err
if err != nil {
return err
}
defer key.Close()
return key.SetDWordValue("DisableDynamicUpdate", 1)
} }
if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
for _, regBase := range []string{ipv4RegBase, ipv6RegBase} { return err
if err := setRegValue(regBase); err != nil {
return err
}
} }
return nil return nil
} }
// setSingleDWORD opens the Registry Key in HKLM for the interface associated
// with the windowsManager and sets the "keyPrefix\value" to data.
func (m windowsManager) setSingleDWORD(prefix winutil.RegistryPathPrefix, value string, data uint32) error {
k, err := m.openInterfaceKey(prefix)
if err != nil {
return err
}
defer k.Close()
return k.SetDWordValue(value, data)
}
func (m windowsManager) GetBaseConfig() (OSConfig, error) { func (m windowsManager) GetBaseConfig() (OSConfig, error) {
resolvers, err := m.getBasePrimaryResolver() resolvers, err := m.getBasePrimaryResolver()
if err != nil { if err != nil {

@ -339,11 +339,12 @@ func deleteFakeGPKey(t *testing.T) {
} }
func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) { func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
basePaths := []string{ipv4RegBase, ipv6RegBase} basePaths := []winutil.RegistryPathPrefix{winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix}
keyPaths := make([]string, 0, len(basePaths)) keyPaths := make([]string, 0, len(basePaths))
guidStr := guid.String()
for _, basePath := range basePaths { for _, basePath := range basePaths {
keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid) keyPath := string(basePath.WithSuffix(guidStr))
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return nil, err return nil, err

@ -1,76 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// The code in this file originates from https://git.zx2c4.com/wireguard-go:
// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING
package dns
import (
"fmt"
"runtime"
"strings"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
deadline := time.Now().Add(timeout)
pathSpl := strings.Split(path, "\\")
for i := 0; ; i++ {
keyName := pathSpl[i]
isLast := i+1 == len(pathSpl)
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("windows.CreateEvent: %v", err)
}
defer windows.CloseHandle(event)
var key registry.Key
for {
err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
if err != nil {
return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err)
}
var accessFlags uint32
if isLast {
accessFlags = access
} else {
accessFlags = registry.NOTIFY
}
key, err = registry.OpenKey(k, keyName, accessFlags)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return 0, fmt.Errorf("timeout waiting for registry key")
}
} else if err != nil {
return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err)
} else {
if isLast {
return key, nil
}
defer key.Close()
break
}
}
k = key
}
}

@ -10,7 +10,9 @@ import (
"log" "log"
"os/exec" "os/exec"
"runtime" "runtime"
"strings"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -391,3 +393,93 @@ func IsCurrentProcessElevated() bool {
return token.IsElevated() return token.IsElevated()
} }
// keyOpenTimeout is how long we wait for a registry key to appear. For some
// reason, registry keys tied to ephemeral interfaces can take a long while to
// appear after interface creation, and we can end up racing with that.
const keyOpenTimeout = 20 * time.Second
// RegistryPath represents a path inside a root registry.Key.
type RegistryPath string
// RegistryPathPrefix specifies a RegistryPath prefix that must be suffixed with
// another RegistryPath to make a valid RegistryPath.
type RegistryPathPrefix string
// WithSuffix returns a RegistryPath with the given suffix appended.
func (p RegistryPathPrefix) WithSuffix(suf string) RegistryPath {
return RegistryPath(string(p) + suf)
}
const (
IPv4TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
IPv6TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
NetBTBase RegistryPath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters`
IPv4TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
IPv6TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
NetBTInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_`
)
// ErrKeyWaitTimeout is returned by OpenKeyWait when calls timeout.
var ErrKeyWaitTimeout = errors.New("timeout waiting for registry key")
// OpenKeyWait opens a registry key, waiting for it to appear if necessary. It
// returns the opened key, or ErrKeyWaitTimeout if the key does not appear
// within 20s. The caller must call Close on the returned key.
func OpenKeyWait(k registry.Key, path RegistryPath, access uint32) (registry.Key, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
deadline := time.Now().Add(keyOpenTimeout)
pathSpl := strings.Split(string(path), "\\")
for i := 0; ; i++ {
keyName := pathSpl[i]
isLast := i+1 == len(pathSpl)
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("windows.CreateEvent: %w", err)
}
defer windows.CloseHandle(event)
var key registry.Key
for {
err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
if err != nil {
return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %w", err)
}
var accessFlags uint32
if isLast {
accessFlags = access
} else {
accessFlags = registry.NOTIFY
}
key, err = registry.OpenKey(k, keyName, accessFlags)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return 0, fmt.Errorf("windows.WaitForSingleObject: %w", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return 0, ErrKeyWaitTimeout
}
} else if err != nil {
return 0, fmt.Errorf("registry.OpenKey(%v): %w", path, err)
} else {
if isLast {
return key, nil
}
defer key.Close()
break
}
}
k = key
}
}

Loading…
Cancel
Save