diff --git a/wgengine/router/dns/manager_windows.go b/wgengine/router/dns/manager_windows.go index a768b4312..1de16b5c2 100644 --- a/wgengine/router/dns/manager_windows.go +++ b/wgengine/router/dns/manager_windows.go @@ -5,9 +5,10 @@ package dns import ( - "errors" "fmt" + "os/exec" "strings" + "syscall" "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/windows/registry" @@ -46,68 +47,16 @@ func setRegistryString(path, name, value string) error { return nil } -func getRegistryString(path, name string) (string, error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ) - if err != nil { - return "", fmt.Errorf("opening %s: %w", path, err) - } - defer key.Close() - - value, _, err := key.GetStringValue(name) - if err != nil { - return "", fmt.Errorf("getting %s[%s]: %w", path, name, err) - } - return value, nil -} - func (m windowsManager) setNameservers(basePath string, nameservers []string) error { path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) value := strings.Join(nameservers, ",") return setRegistryString(path, "NameServer", value) } -func (m windowsManager) setDomains(path string, oldDomains, newDomains []string) error { - // We reimplement setRegistryString to ensure that we hold the key for the whole operation. - key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE) - if err != nil { - return fmt.Errorf("opening %s: %w", path, err) - } - defer key.Close() - - searchList, _, err := key.GetStringValue("SearchList") - if err != nil && err != registry.ErrNotExist { - return fmt.Errorf("getting %s[SearchList]: %w", path, err) - } - currentDomains := strings.Split(searchList, ",") - - var domainsToSet []string - for _, domain := range currentDomains { - inOld, inNew := false, false - - // The number of domains should be small, - // so this is probaly faster than constructing a map. - for _, oldDomain := range oldDomains { - if domain == oldDomain { - inOld = true - } - } - for _, newDomain := range newDomains { - if domain == newDomain { - inNew = true - } - } - - if !inNew && !inOld { - domainsToSet = append(domainsToSet, domain) - } - } - domainsToSet = append(domainsToSet, newDomains...) - - searchList = strings.Join(domainsToSet, ",") - if err := key.SetStringValue("SearchList", searchList); err != nil { - return fmt.Errorf("setting %s[SearchList]: %w", path, err) - } - return nil +func (m windowsManager) setDomains(basePath string, domains []string) error { + path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) + value := strings.Join(domains, ",") + return setRegistryString(path, "SearchList", value) } func (m windowsManager) Up(config Config) error { @@ -122,23 +71,17 @@ func (m windowsManager) Up(config Config) error { } } - lastSearchList, err := getRegistryString(tsRegBase, "SearchList") - if err != nil && !errors.Is(err, registry.ErrNotExist) { - return err - } - lastDomains := strings.Split(lastSearchList, ",") - if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil { return err } - if err := m.setDomains(ipv4RegBase, lastDomains, config.Domains); err != nil { + if err := m.setDomains(ipv4RegBase, config.Domains); err != nil { return err } if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil { return err } - if err := m.setDomains(ipv6RegBase, lastDomains, config.Domains); err != nil { + if err := m.setDomains(ipv6RegBase, config.Domains); err != nil { return err } @@ -147,6 +90,19 @@ func (m windowsManager) Up(config Config) error { return err } + // Force DNS re-registration in Active Directory. What we actually + // care about is that this command invokes the undocumented hidden + // function that forces Windows to notice that adapter settings + // have changed, which makes the DNS settings actually take + // effect. + // + // This command can take a few seconds to run. + cmd := exec.Command("ipconfig", "/registerdns") + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := cmd.Run(); err != nil { + return fmt.Errorf("running ipconfig /registerdns: %w", err) + } + return nil }