From 8cdfd12977da037617de694c2a070ddddef4280f Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Wed, 1 Jun 2022 14:41:11 -0600 Subject: [PATCH] net/dns: update Windows split DNS settings to work alongside other NRPT entries set by group policy. When there are group policy entries for the NRPT that do not belong to Tailscale, we recognize that we need to add ourselves to group policy and use that registry key instead of the local one. We also refresh the group policy settings as necessary to ensure that our changes take effect immediately. Fixes https://github.com/tailscale/tailscale/issues/4607 Signed-off-by: Aaron Klotz --- net/dns/manager_windows.go | 162 ++-------------- net/dns/manager_windows_test.go | 254 ++++++++++++++++++++++-- net/dns/nrpt_windows.go | 331 ++++++++++++++++++++++++++++++++ 3 files changed, 587 insertions(+), 160 deletions(-) create mode 100644 net/dns/nrpt_windows.go diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 2d4857f19..0a82b10aa 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -20,27 +20,12 @@ import ( "tailscale.com/envknob" "tailscale.com/types/logger" "tailscale.com/util/dnsname" - "tailscale.com/util/winutil" ) const ( ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` - nrptBase = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\` - nrptOverrideDNS = 0x8 // bitmask value for "use the provided override DNS resolvers" - - // This is the legacy rule ID that previous versions used when we supported - // only a single rule. Now that we support multiple rules are required, we - // generate their GUIDs and store them under the Tailscale registry key. - nrptSingleRuleID = `{5abe529b-675b-4486-8459-25a634dacc23}` - // Apparently NRPT rules cannot handle > 50 domains. - nrptMaxDomainsPerRule = 50 - - // This is the name of the registry value we use to save Rule IDs under - // the Tailscale registry key. - nrptRuleIDValueName = `NRPTRuleIDs` - versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion` ) @@ -49,35 +34,19 @@ var configureWSL = envknob.Bool("TS_DEBUG_CONFIGURE_WSL") type windowsManager struct { logf logger.Logf guid string - nrptWorks bool + nrptDB *nrptRuleDatabase wslManager *wslManager } -func loadRuleSubkeyNames() []string { - result := winutil.GetRegStrings(nrptRuleIDValueName, nil) - if result == nil { - // Use the legacy rule ID if none are specified in our registry key - result = []string{nrptSingleRuleID} - } - return result -} - func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator, error) { ret := windowsManager{ logf: logf, guid: interfaceName, - nrptWorks: isWindows10OrBetter(), wslManager: newWSLManager(logf), } - // Best-effort: if our NRPT rule exists, try to delete it. Unlike - // per-interface configuration, NRPT rules survive the unclean - // termination of the Tailscale process, and depending on the - // rule, it may prevent us from reaching login.tailscale.com to - // boot up. The bootstrap resolver logic will save us, but it - // slows down start-up a bunch. - if ret.nrptWorks { - ret.delAllRuleKeys() + if isWindows10OrBetter() { + ret.nrptDB = newNRPTRuleDatabase(logf) } // Log WSL status once at startup. @@ -108,29 +77,6 @@ func (m windowsManager) ifPath(basePath string) string { return fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) } -func (m windowsManager) delAllRuleKeys() error { - nrptRuleIDs := loadRuleSubkeyNames() - if err := m.delRuleKeys(nrptRuleIDs); err != nil { - return err - } - if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { - m.logf("Error deleting registry value %q: %v", nrptRuleIDValueName, err) - return err - } - return nil -} - -func (m windowsManager) delRuleKeys(nrptRuleIDs []string) error { - for _, rid := range nrptRuleIDs { - keyName := nrptBase + rid - if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist { - m.logf("Error deleting NRPT rule key %q: %v", keyName, err) - return err - } - } - return nil -} - func delValue(key registry.Key, name string) error { if err := key.DeleteValue(name); err != nil && err != registry.ErrNotExist { return err @@ -144,8 +90,17 @@ func delValue(key registry.Key, name string) error { // // If no resolvers are provided, the Tailscale NRPT rules are deleted. func (m windowsManager) setSplitDNS(resolvers []netaddr.IP, domains []dnsname.FQDN) error { + if m.nrptDB == nil { + if resolvers == nil { + // Just a no-op in this case. + return nil + } + return fmt.Errorf("Split DNS unsupported on this Windows version") + } + + defer m.nrptDB.Refresh() if len(resolvers) == 0 { - return m.delAllRuleKeys() + return m.nrptDB.DelAllRuleKeys() } servers := make([]string, 0, len(resolvers)) @@ -153,92 +108,7 @@ func (m windowsManager) setSplitDNS(resolvers []netaddr.IP, domains []dnsname.FQ servers = append(servers, resolver.String()) } - // NRPT has an undocumented restriction that each rule may only be associated - // with a maximum of 50 domains. If we are setting rules for more domains - // than that, we need to split domains into chunks and write out a rule per chunk. - dq := len(domains) / nrptMaxDomainsPerRule - dr := len(domains) % nrptMaxDomainsPerRule - - domainRulesLen := dq - if dr > 0 { - domainRulesLen++ - } - - nrptRuleIDs := loadRuleSubkeyNames() - for len(nrptRuleIDs) < domainRulesLen { - guid, err := windows.GenerateGUID() - if err != nil { - return err - } - nrptRuleIDs = append(nrptRuleIDs, guid.String()) - } - - // Remove any surplus rules that are no longer needed. - ruleIDsToRemove := nrptRuleIDs[domainRulesLen:] - m.delRuleKeys(ruleIDsToRemove) - - // We need to save the list of rule IDs to our Tailscale registry key so that - // we know which rules are ours during subsequent modifications to NRPT rules. - ruleIDsToWrite := nrptRuleIDs[:domainRulesLen] - if len(ruleIDsToWrite) > 0 { - if err := winutil.SetRegStrings(nrptRuleIDValueName, ruleIDsToWrite); err != nil { - return err - } - } else { - if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { - return err - } - } - - doms := make([]string, 0, nrptMaxDomainsPerRule) - for i := 0; i < domainRulesLen; i++ { - // Each iteration consumes nrptMaxDomainsPerRule domains... - curLen := nrptMaxDomainsPerRule - // Except for the final iteration: when we have a remainder, use that instead. - if i == domainRulesLen-1 && dr > 0 { - curLen = dr - } - - // Obtain the slice of domains to consume within the current iteration. - start := i * nrptMaxDomainsPerRule - end := start + curLen - for _, domain := range domains[start:end] { - // NRPT rules must have a leading dot, which is not usual for - // DNS search paths. - doms = append(doms, "."+domain.WithoutTrailingDot()) - } - - if err := writeNRPTRule(nrptRuleIDs[i], doms, servers); err != nil { - return err - } - - doms = doms[:0] - } - - return nil -} - -func writeNRPTRule(ruleID string, doms, servers []string) error { - // CreateKey is actually open-or-create, which suits us fine. - key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBase+ruleID, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("opening %s: %w", nrptBase+ruleID, err) - } - defer key.Close() - if err := key.SetDWordValue("Version", 1); err != nil { - return err - } - if err := key.SetStringsValue("Name", doms); err != nil { - return err - } - if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil { - return err - } - if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { - return err - } - - return nil + return m.nrptDB.WriteSplitDNSConfig(servers, domains) } // setPrimaryDNS sets the given resolvers and domains as the Tailscale @@ -352,7 +222,7 @@ func (m windowsManager) SetDNS(cfg OSConfig) error { if err := m.setPrimaryDNS(cfg.Nameservers, cfg.SearchDomains); err != nil { return err } - } else if !m.nrptWorks { + } else if m.nrptDB == nil { return errors.New("cannot set per-domain resolvers on Windows 7") } else { if err := m.setSplitDNS(cfg.Nameservers, cfg.MatchDomains); err != nil { @@ -418,7 +288,7 @@ func (m windowsManager) SetDNS(cfg OSConfig) error { } func (m windowsManager) SupportsSplitDNS() bool { - return m.nrptWorks + return m.nrptDB != nil } func (m windowsManager) Close() error { diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 820785a4b..ac9370587 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -5,6 +5,7 @@ package dns import ( + "fmt" "math/rand" "strings" "testing" @@ -17,11 +18,62 @@ import ( "tailscale.com/util/winutil" ) -func TestManagerWindows(t *testing.T) { - if !winutil.IsCurrentProcessElevated() { - t.Skipf("test requires running as elevated user") +const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}" + +var ( + procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification") + procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification") +) + +func TestManagerWindowsLocal(t *testing.T) { + if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user on Windows 10+") + } + + runTest(t, true) +} + +func TestManagerWindowsGP(t *testing.T) { + if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user on Windows 10+") + } + + checkGPNotificationsWork(t) + + // Make sure group policy is refreshed before this test exits but after we've + // cleaned everything else up. + defer procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE)) + + err := createFakeGPKey() + if err != nil { + t.Fatalf("Creating fake GP key: %v\n", err) } + defer deleteFakeGPKey(t) + runTest(t, false) +} + +func checkGPNotificationsWork(t *testing.T) { + // Test to ensure that RegisterGPNotification work on this machine, + // otherwise this test will fail. + trk, err := newGPNotificationTracker() + if err != nil { + t.Skipf("newGPNotificationTracker error: %v\n", err) + } + defer trk.Close() + + r, _, err := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE)) + if r == 0 { + t.Fatalf("RefreshPolicyEx error: %v\n", err) + } + + timeout := uint32(10000) // Milliseconds + if !trk.DidRefreshTimeout(timeout) { + t.Skipf("GP notifications are not working on this machine\n") + } +} + +func runTest(t *testing.T, isLocal bool) { logf := func(format string, args ...any) { t.Logf(format, args...) } @@ -37,6 +89,11 @@ func TestManagerWindows(t *testing.T) { } mgr := cfg.(windowsManager) + usingGP := mgr.nrptDB.writeAsGP + if isLocal == usingGP { + t.Fatalf("usingGP %v, want %v\n", usingGP, !usingGP) + } + // Upon initialization of cfg, we should not have any NRPT rules ensureNoRules(t) @@ -74,14 +131,61 @@ func TestManagerWindows(t *testing.T) { 51, } - for _, n := range cases { + var regBaseValidate string + var regBaseEnsure string + if isLocal { + regBaseValidate = nrptBaseLocal + regBaseEnsure = nrptBaseGP + } else { + regBaseValidate = nrptBaseGP + regBaseEnsure = nrptBaseLocal + } + + var trk *gpNotificationTracker + if isLocal { + // (dblohm7) When isLocal == true, we keep trk active through the entire + // sequence of test cases, and then we verify that no policy notifications + // occurred. Because policy notifications are scoped to the entire computer, + // this check could potentially fail if another process concurrently modifies + // group policies while this test is running. I don't expect this to be an + // issue on any computer on which we run this test, but something to keep in + // mind if we start seeing flakiness around these GP notifications. + trk, err = newGPNotificationTracker() + if err != nil { + t.Fatalf("newGPNotificationTracker: %v\n", err) + } + defer trk.Close() + } + + runCase := func(n int) { t.Logf("Test case: %d domains\n", n) + if !isLocal { + // When !isLocal, we want to check that a GP notification occured for + // every single test case. + trk, err = newGPNotificationTracker() + if err != nil { + t.Fatalf("newGPNotificationTracker: %v\n", err) + } + defer trk.Close() + } caseDomains := domains[:n] - err := mgr.setSplitDNS(resolvers, caseDomains) + err = mgr.setSplitDNS(resolvers, caseDomains) if err != nil { t.Fatalf("setSplitDNS: %v\n", err) } - validateRegistry(t, caseDomains) + validateRegistry(t, regBaseValidate, caseDomains) + ensureNoRulesInSubkey(t, regBaseEnsure) + if !isLocal && !trk.DidRefresh(true) { + t.Fatalf("DidRefresh false, want true\n") + } + } + + for _, n := range cases { + runCase(n) + } + + if isLocal && trk.DidRefresh(false) { + t.Errorf("DidRefresh true, want false\n") } t.Logf("Test case: nil resolver\n") @@ -92,23 +196,92 @@ func TestManagerWindows(t *testing.T) { ensureNoRules(t) } +func createFakeGPKey() error { + keyStr := nrptBaseGP + `\` + testGPRuleID + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %s: %w", keyStr, err) + } + defer key.Close() + if err := key.SetDWordValue("Version", 1); err != nil { + return err + } + if err := key.SetStringsValue("Name", []string{"._setbygp_.example.com"}); err != nil { + return err + } + if err := key.SetStringValue("GenericDNSServers", "1.1.1.1"); err != nil { + return err + } + if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { + return err + } + return nil +} + +func deleteFakeGPKey(t *testing.T) { + keyName := nrptBaseGP + `\` + testGPRuleID + if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist { + t.Fatalf("Error deleting NRPT rule key %q: %v\n", keyName, err) + } + + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil { + t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err) + } + + if !isEmpty { + return + } + + if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil { + t.Fatalf("Deleting DnsPolicyKey Subkey: %v", err) + } +} + func ensureNoRules(t *testing.T) { ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil) if ruleIDs != nil { t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs) } - legacyKeyPath := nrptBase + nrptSingleRuleID - key, err := registry.OpenKey(registry.LOCAL_MACHINE, legacyKeyPath, registry.READ) + for _, base := range []string{nrptBaseLocal, nrptBaseGP} { + ensureNoSingleRule(t, base) + } +} + +func ensureNoRulesInSubkey(t *testing.T, base string) { + ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil) + if ruleIDs == nil { + for _, base := range []string{nrptBaseLocal, nrptBaseGP} { + ensureNoSingleRule(t, base) + } + return + } + + for _, ruleID := range ruleIDs { + keyName := base + `\` + ruleID + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ) + if err == nil { + key.Close() + } + if err != registry.ErrNotExist { + t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist) + } + } +} + +func ensureNoSingleRule(t *testing.T, base string) { + singleKeyPath := base + `\` + nrptSingleRuleID + key, err := registry.OpenKey(registry.LOCAL_MACHINE, singleKeyPath, registry.READ) if err == nil { key.Close() } if err != registry.ErrNotExist { - t.Errorf("%s: %q, want %q\n", legacyKeyPath, err, registry.ErrNotExist) + t.Fatalf("%s: %q, want %q\n", singleKeyPath, err, registry.ErrNotExist) } } -func validateRegistry(t *testing.T, domains []dnsname.FQDN) { +func validateRegistry(t *testing.T, nrptBase string, domains []dnsname.FQDN) { q := len(domains) / nrptMaxDomainsPerRule r := len(domains) % nrptMaxDomainsPerRule numRules := q @@ -124,9 +297,9 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) { } for i, ruleID := range ruleIDs { - savedDomains, err := getSavedDomainsForRule(ruleID) + savedDomains, err := getSavedDomainsForRule(nrptBase, ruleID) if err != nil { - t.Fatalf("getSavedDomainsForRule(%q): %v\n", ruleID, err) + t.Fatalf("getSavedDomainsForRule(%q, %q): %v\n", nrptBase, ruleID, err) } start := i * nrptMaxDomainsPerRule @@ -148,8 +321,8 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) { } } -func getSavedDomainsForRule(ruleID string) ([]string, error) { - keyPath := nrptBase + ruleID +func getSavedDomainsForRule(base, ruleID string) ([]string, error) { + keyPath := base + `\` + ruleID key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ) if err != nil { return nil, err @@ -158,3 +331,56 @@ func getSavedDomainsForRule(ruleID string) ([]string, error) { result, _, err := key.GetStringsValue("Name") return result, err } + +// gpNotificationTracker registers with the Windows policy engine and receives +// notifications when policy refreshes occur. +type gpNotificationTracker struct { + event windows.Handle +} + +func newGPNotificationTracker() (*gpNotificationTracker, error) { + var err error + evt, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + windows.CloseHandle(evt) + } + }() + + ok, _, e := procRegisterGPNotification.Call( + uintptr(evt), + uintptr(1), // We want computer policy changes, not user policy changes. + ) + if ok == 0 { + err = e + return nil, err + } + + return &gpNotificationTracker{evt}, nil +} + +func (trk *gpNotificationTracker) DidRefresh(isExpected bool) bool { + // If we're not expecting a refresh event, then we need to use a timeout. + timeout := uint32(1000) // 1 second (in milliseconds) + if isExpected { + // Otherwise, since it is imperative that we see an event, we wait infinitely. + timeout = windows.INFINITE + } + + return trk.DidRefreshTimeout(timeout) +} + +func (trk *gpNotificationTracker) DidRefreshTimeout(timeout uint32) bool { + waitCode, _ := windows.WaitForSingleObject(trk.event, timeout) + return waitCode == windows.WAIT_OBJECT_0 +} + +func (trk *gpNotificationTracker) Close() error { + procUnregisterGPNotification.Call(uintptr(trk.event)) + windows.CloseHandle(trk.event) + trk.event = 0 + return nil +} diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go new file mode 100644 index 000000000..ceef38107 --- /dev/null +++ b/net/dns/nrpt_windows.go @@ -0,0 +1,331 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "fmt" + "strings" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/winutil" +) + +const ( + dnsBaseGP = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient` + nrptBaseLocal = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig` + nrptBaseGP = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig` + + nrptOverrideDNS = 0x8 // bitmask value for "use the provided override DNS resolvers" + + // Apparently NRPT rules cannot handle > 50 domains. + nrptMaxDomainsPerRule = 50 + + // This is the legacy rule ID that previous versions used when we supported + // only a single rule. Now that we support multiple rules are required, we + // generate their GUIDs and store them under the Tailscale registry key. + nrptSingleRuleID = `{5abe529b-675b-4486-8459-25a634dacc23}` + + // This is the name of the registry value we use to save Rule IDs under + // the Tailscale registry key. + nrptRuleIDValueName = `NRPTRuleIDs` +) + +var ( + libUserenv = windows.NewLazySystemDLL("userenv.dll") + procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx") +) + +const _RP_FORCE = 1 // Flag for RefreshPolicyEx + +// nrptRuleDatabase ensapsulates access to the Windows Name Resolution Policy +// Table (NRPT). +type nrptRuleDatabase struct { + logf logger.Logf + ruleIDs []string + writeAsGP bool + isGPDirty bool +} + +func newNRPTRuleDatabase(logf logger.Logf) *nrptRuleDatabase { + ret := &nrptRuleDatabase{logf: logf} + ret.loadRuleSubkeyNames() + ret.initWriteAsGP() + logf("nrptRuleDatabase using group policy: %v\n", ret.writeAsGP) + // Best-effort: if our NRPT rule exists, try to delete it. Unlike + // per-interface configuration, NRPT rules survive the unclean + // termination of the Tailscale process, and depending on the + // rule, it may prevent us from reaching login.tailscale.com to + // boot up. The bootstrap resolver logic will save us, but it + // slows down start-up a bunch. + ret.DelAllRuleKeys() + return ret +} + +func (db *nrptRuleDatabase) loadRuleSubkeyNames() { + result := winutil.GetRegStrings(nrptRuleIDValueName, nil) + if result == nil { + // Use the legacy rule ID if none are specified in our registry key + result = []string{nrptSingleRuleID} + } + db.ruleIDs = result +} + +// initWriteAsGP determines which registry path should be used for writing +// NRPT rules. If there are rules in the GP path that don't belong to us, then +// we should use the GP path. +func (db *nrptRuleDatabase) initWriteAsGP() { + var err error + defer func() { + if err != nil { + db.writeAsGP = false + } + }() + + dnsKey, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ) + if err != nil { + db.logf("Failed to open key %q with error: %v\n", dnsBaseGP, err) + return + } + defer dnsKey.Close() + + ki, err := dnsKey.Stat() + if err != nil { + db.logf("Failed to stat key %q with error: %v\n", dnsBaseGP, err) + return + } + + // If the dnsKey contains any values, then we need to use the GP key. + if ki.ValueCount > 0 { + db.writeAsGP = true + return + } + + if ki.SubKeyCount == 0 { + // If dnsKey contains no values and no subkeys, then we definitely don't + // need to use the GP key. + db.writeAsGP = false + return + } + + // Get a list of all the NRPT rules under the GP subkey. + nrptKey, err := registry.OpenKey(registry.LOCAL_MACHINE, nrptBaseGP, registry.READ) + if err != nil { + db.logf("Failed to open key %q with error: %v\n", nrptBaseGP, err) + return + } + defer nrptKey.Close() + + gpSubkeyNames, err := nrptKey.ReadSubKeyNames(0) + if err != nil { + db.logf("Failed to list subkeys under %q with error: %v\n", nrptBaseGP, err) + return + } + + // Add *all* rules from the GP subkey into a set. + gpSubkeyMap := make(map[string]struct{}, len(gpSubkeyNames)) + for _, gpSubkey := range gpSubkeyNames { + gpSubkeyMap[strings.ToUpper(gpSubkey)] = struct{}{} + } + + // Remove *our* rules from the set. + for _, ourRuleID := range db.ruleIDs { + delete(gpSubkeyMap, strings.ToUpper(ourRuleID)) + } + + // Any leftover rules do not belong to us. When group policy is being used + // by something else, we must also use the GP path. + db.writeAsGP = len(gpSubkeyMap) > 0 +} + +// DelAllRuleKeys removes any and all NRPT rules that are owned by Tailscale. +func (db *nrptRuleDatabase) DelAllRuleKeys() error { + if err := db.delRuleKeys(db.ruleIDs); err != nil { + return err + } + if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { + db.logf("Error deleting registry value %q: %v", nrptRuleIDValueName, err) + return err + } + db.ruleIDs = nil + return nil +} + +// delRuleKeys removes the NRPT rules specified by nrptRuleIDs from the +// Windows registry. It attempts to remove the rules from both possible registry +// keys: the local key and the group policy key. +func (db *nrptRuleDatabase) delRuleKeys(nrptRuleIDs []string) error { + for _, rid := range nrptRuleIDs { + keyNameLocal := nrptBaseLocal + `\` + rid + if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyNameLocal); err != nil && err != registry.ErrNotExist { + db.logf("Error deleting NRPT rule key %q: %v", keyNameLocal, err) + return err + } + + keyNameGP := nrptBaseGP + `\` + rid + err := registry.DeleteKey(registry.LOCAL_MACHINE, keyNameGP) + if err == nil { + // If this deleted subkey existed under the GP key, we will need to refresh. + db.isGPDirty = true + } else if err != registry.ErrNotExist { + db.logf("Error deleting NRPT rule key %q: %v", keyNameGP, err) + return err + } + } + + if !db.isGPDirty { + return nil + } + + // If we've removed keys from the Group Policy subkey, and the DNSPolicyConfig + // subkey is now empty, we need to remove that subkey. + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil || !isEmpty { + return err + } + + return registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP) +} + +// isPolicyConfigSubkeyEmpty returns true if and only if the nrptBaseGP exists +// and does not contain any values or subkeys. +func isPolicyConfigSubkeyEmpty() (bool, error) { + subKey, err := registry.OpenKey(registry.LOCAL_MACHINE, nrptBaseGP, registry.READ) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + defer subKey.Close() + + ki, err := subKey.Stat() + if err != nil { + return false, err + } + + return (ki.ValueCount == 0 && ki.SubKeyCount == 0), nil +} + +func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsname.FQDN) error { + // NRPT has an undocumented restriction that each rule may only be associated + // with a maximum of 50 domains. If we are setting rules for more domains + // than that, we need to split domains into chunks and write out a rule per chunk. + dq := len(domains) / nrptMaxDomainsPerRule + dr := len(domains) % nrptMaxDomainsPerRule + + domainRulesLen := dq + if dr > 0 { + domainRulesLen++ + } + + db.loadRuleSubkeyNames() + for len(db.ruleIDs) < domainRulesLen { + guid, err := windows.GenerateGUID() + if err != nil { + return err + } + db.ruleIDs = append(db.ruleIDs, guid.String()) + } + + // Remove any surplus rules that are no longer needed. + ruleIDsToRemove := db.ruleIDs[domainRulesLen:] + db.delRuleKeys(ruleIDsToRemove) + + // We need to save the list of rule IDs to our Tailscale registry key so that + // we know which rules are ours during subsequent modifications to NRPT rules. + ruleIDsToWrite := db.ruleIDs[:domainRulesLen] + if len(ruleIDsToWrite) == 0 { + if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { + return err + } + db.ruleIDs = nil + return nil + } + + if err := winutil.SetRegStrings(nrptRuleIDValueName, ruleIDsToWrite); err != nil { + return err + } + db.ruleIDs = ruleIDsToWrite + + curRuleID := 0 + doms := make([]string, 0, nrptMaxDomainsPerRule) + + for _, domain := range domains { + if len(doms) == nrptMaxDomainsPerRule { + if err := db.writeNRPTRule(db.ruleIDs[curRuleID], servers, doms); err != nil { + return err + } + curRuleID++ + doms = doms[:0] + } + + // NRPT rules must have a leading dot, which is not usual for + // DNS search paths. + doms = append(doms, "."+domain.WithoutTrailingDot()) + } + + if len(doms) > 0 { + if err := db.writeNRPTRule(db.ruleIDs[curRuleID], servers, doms); err != nil { + return err + } + } + + return nil +} + +// Refresh notifies the Windows group policy engine when policies have changed. +func (db *nrptRuleDatabase) Refresh() { + if !db.isGPDirty { + return + } + ok, _, err := procRefreshPolicyEx.Call( + uintptr(1), // Win32 TRUE: Refresh computer policy, not user policy. + uintptr(_RP_FORCE), + ) + if ok == 0 { + db.logf("RefreshPolicyEx failed: %v", err) + return + } + db.isGPDirty = false +} + +func (db *nrptRuleDatabase) writeNRPTRule(ruleID string, servers, doms []string) error { + var nrptBase string + if db.writeAsGP { + nrptBase = nrptBaseGP + } else { + nrptBase = nrptBaseLocal + } + + keyStr := nrptBase + `\` + ruleID + + // CreateKey is actually open-or-create, which suits us fine. + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %s: %w", keyStr, err) + } + defer key.Close() + if err := key.SetDWordValue("Version", 1); err != nil { + return err + } + if err := key.SetStringsValue("Name", doms); err != nil { + return err + } + if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil { + return err + } + if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { + return err + } + + if db.writeAsGP { + db.isGPDirty = true + } + + return nil +}