diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 2d4857f19..0a82b10aa 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -20,27 +20,12 @@ import ( "tailscale.com/envknob" "tailscale.com/types/logger" "tailscale.com/util/dnsname" - "tailscale.com/util/winutil" ) const ( ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` - nrptBase = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\` - nrptOverrideDNS = 0x8 // bitmask value for "use the provided override DNS resolvers" - - // This is the legacy rule ID that previous versions used when we supported - // only a single rule. Now that we support multiple rules are required, we - // generate their GUIDs and store them under the Tailscale registry key. - nrptSingleRuleID = `{5abe529b-675b-4486-8459-25a634dacc23}` - // Apparently NRPT rules cannot handle > 50 domains. - nrptMaxDomainsPerRule = 50 - - // This is the name of the registry value we use to save Rule IDs under - // the Tailscale registry key. - nrptRuleIDValueName = `NRPTRuleIDs` - versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion` ) @@ -49,35 +34,19 @@ var configureWSL = envknob.Bool("TS_DEBUG_CONFIGURE_WSL") type windowsManager struct { logf logger.Logf guid string - nrptWorks bool + nrptDB *nrptRuleDatabase wslManager *wslManager } -func loadRuleSubkeyNames() []string { - result := winutil.GetRegStrings(nrptRuleIDValueName, nil) - if result == nil { - // Use the legacy rule ID if none are specified in our registry key - result = []string{nrptSingleRuleID} - } - return result -} - func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator, error) { ret := windowsManager{ logf: logf, guid: interfaceName, - nrptWorks: isWindows10OrBetter(), wslManager: newWSLManager(logf), } - // Best-effort: if our NRPT rule exists, try to delete it. Unlike - // per-interface configuration, NRPT rules survive the unclean - // termination of the Tailscale process, and depending on the - // rule, it may prevent us from reaching login.tailscale.com to - // boot up. The bootstrap resolver logic will save us, but it - // slows down start-up a bunch. - if ret.nrptWorks { - ret.delAllRuleKeys() + if isWindows10OrBetter() { + ret.nrptDB = newNRPTRuleDatabase(logf) } // Log WSL status once at startup. @@ -108,29 +77,6 @@ func (m windowsManager) ifPath(basePath string) string { return fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) } -func (m windowsManager) delAllRuleKeys() error { - nrptRuleIDs := loadRuleSubkeyNames() - if err := m.delRuleKeys(nrptRuleIDs); err != nil { - return err - } - if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { - m.logf("Error deleting registry value %q: %v", nrptRuleIDValueName, err) - return err - } - return nil -} - -func (m windowsManager) delRuleKeys(nrptRuleIDs []string) error { - for _, rid := range nrptRuleIDs { - keyName := nrptBase + rid - if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist { - m.logf("Error deleting NRPT rule key %q: %v", keyName, err) - return err - } - } - return nil -} - func delValue(key registry.Key, name string) error { if err := key.DeleteValue(name); err != nil && err != registry.ErrNotExist { return err @@ -144,8 +90,17 @@ func delValue(key registry.Key, name string) error { // // If no resolvers are provided, the Tailscale NRPT rules are deleted. func (m windowsManager) setSplitDNS(resolvers []netaddr.IP, domains []dnsname.FQDN) error { + if m.nrptDB == nil { + if resolvers == nil { + // Just a no-op in this case. + return nil + } + return fmt.Errorf("Split DNS unsupported on this Windows version") + } + + defer m.nrptDB.Refresh() if len(resolvers) == 0 { - return m.delAllRuleKeys() + return m.nrptDB.DelAllRuleKeys() } servers := make([]string, 0, len(resolvers)) @@ -153,92 +108,7 @@ func (m windowsManager) setSplitDNS(resolvers []netaddr.IP, domains []dnsname.FQ servers = append(servers, resolver.String()) } - // NRPT has an undocumented restriction that each rule may only be associated - // with a maximum of 50 domains. If we are setting rules for more domains - // than that, we need to split domains into chunks and write out a rule per chunk. - dq := len(domains) / nrptMaxDomainsPerRule - dr := len(domains) % nrptMaxDomainsPerRule - - domainRulesLen := dq - if dr > 0 { - domainRulesLen++ - } - - nrptRuleIDs := loadRuleSubkeyNames() - for len(nrptRuleIDs) < domainRulesLen { - guid, err := windows.GenerateGUID() - if err != nil { - return err - } - nrptRuleIDs = append(nrptRuleIDs, guid.String()) - } - - // Remove any surplus rules that are no longer needed. - ruleIDsToRemove := nrptRuleIDs[domainRulesLen:] - m.delRuleKeys(ruleIDsToRemove) - - // We need to save the list of rule IDs to our Tailscale registry key so that - // we know which rules are ours during subsequent modifications to NRPT rules. - ruleIDsToWrite := nrptRuleIDs[:domainRulesLen] - if len(ruleIDsToWrite) > 0 { - if err := winutil.SetRegStrings(nrptRuleIDValueName, ruleIDsToWrite); err != nil { - return err - } - } else { - if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { - return err - } - } - - doms := make([]string, 0, nrptMaxDomainsPerRule) - for i := 0; i < domainRulesLen; i++ { - // Each iteration consumes nrptMaxDomainsPerRule domains... - curLen := nrptMaxDomainsPerRule - // Except for the final iteration: when we have a remainder, use that instead. - if i == domainRulesLen-1 && dr > 0 { - curLen = dr - } - - // Obtain the slice of domains to consume within the current iteration. - start := i * nrptMaxDomainsPerRule - end := start + curLen - for _, domain := range domains[start:end] { - // NRPT rules must have a leading dot, which is not usual for - // DNS search paths. - doms = append(doms, "."+domain.WithoutTrailingDot()) - } - - if err := writeNRPTRule(nrptRuleIDs[i], doms, servers); err != nil { - return err - } - - doms = doms[:0] - } - - return nil -} - -func writeNRPTRule(ruleID string, doms, servers []string) error { - // CreateKey is actually open-or-create, which suits us fine. - key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBase+ruleID, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("opening %s: %w", nrptBase+ruleID, err) - } - defer key.Close() - if err := key.SetDWordValue("Version", 1); err != nil { - return err - } - if err := key.SetStringsValue("Name", doms); err != nil { - return err - } - if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil { - return err - } - if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { - return err - } - - return nil + return m.nrptDB.WriteSplitDNSConfig(servers, domains) } // setPrimaryDNS sets the given resolvers and domains as the Tailscale @@ -352,7 +222,7 @@ func (m windowsManager) SetDNS(cfg OSConfig) error { if err := m.setPrimaryDNS(cfg.Nameservers, cfg.SearchDomains); err != nil { return err } - } else if !m.nrptWorks { + } else if m.nrptDB == nil { return errors.New("cannot set per-domain resolvers on Windows 7") } else { if err := m.setSplitDNS(cfg.Nameservers, cfg.MatchDomains); err != nil { @@ -418,7 +288,7 @@ func (m windowsManager) SetDNS(cfg OSConfig) error { } func (m windowsManager) SupportsSplitDNS() bool { - return m.nrptWorks + return m.nrptDB != nil } func (m windowsManager) Close() error { diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 820785a4b..ac9370587 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -5,6 +5,7 @@ package dns import ( + "fmt" "math/rand" "strings" "testing" @@ -17,11 +18,62 @@ import ( "tailscale.com/util/winutil" ) -func TestManagerWindows(t *testing.T) { - if !winutil.IsCurrentProcessElevated() { - t.Skipf("test requires running as elevated user") +const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}" + +var ( + procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification") + procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification") +) + +func TestManagerWindowsLocal(t *testing.T) { + if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user on Windows 10+") + } + + runTest(t, true) +} + +func TestManagerWindowsGP(t *testing.T) { + if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user on Windows 10+") + } + + checkGPNotificationsWork(t) + + // Make sure group policy is refreshed before this test exits but after we've + // cleaned everything else up. + defer procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE)) + + err := createFakeGPKey() + if err != nil { + t.Fatalf("Creating fake GP key: %v\n", err) } + defer deleteFakeGPKey(t) + runTest(t, false) +} + +func checkGPNotificationsWork(t *testing.T) { + // Test to ensure that RegisterGPNotification work on this machine, + // otherwise this test will fail. + trk, err := newGPNotificationTracker() + if err != nil { + t.Skipf("newGPNotificationTracker error: %v\n", err) + } + defer trk.Close() + + r, _, err := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE)) + if r == 0 { + t.Fatalf("RefreshPolicyEx error: %v\n", err) + } + + timeout := uint32(10000) // Milliseconds + if !trk.DidRefreshTimeout(timeout) { + t.Skipf("GP notifications are not working on this machine\n") + } +} + +func runTest(t *testing.T, isLocal bool) { logf := func(format string, args ...any) { t.Logf(format, args...) } @@ -37,6 +89,11 @@ func TestManagerWindows(t *testing.T) { } mgr := cfg.(windowsManager) + usingGP := mgr.nrptDB.writeAsGP + if isLocal == usingGP { + t.Fatalf("usingGP %v, want %v\n", usingGP, !usingGP) + } + // Upon initialization of cfg, we should not have any NRPT rules ensureNoRules(t) @@ -74,14 +131,61 @@ func TestManagerWindows(t *testing.T) { 51, } - for _, n := range cases { + var regBaseValidate string + var regBaseEnsure string + if isLocal { + regBaseValidate = nrptBaseLocal + regBaseEnsure = nrptBaseGP + } else { + regBaseValidate = nrptBaseGP + regBaseEnsure = nrptBaseLocal + } + + var trk *gpNotificationTracker + if isLocal { + // (dblohm7) When isLocal == true, we keep trk active through the entire + // sequence of test cases, and then we verify that no policy notifications + // occurred. Because policy notifications are scoped to the entire computer, + // this check could potentially fail if another process concurrently modifies + // group policies while this test is running. I don't expect this to be an + // issue on any computer on which we run this test, but something to keep in + // mind if we start seeing flakiness around these GP notifications. + trk, err = newGPNotificationTracker() + if err != nil { + t.Fatalf("newGPNotificationTracker: %v\n", err) + } + defer trk.Close() + } + + runCase := func(n int) { t.Logf("Test case: %d domains\n", n) + if !isLocal { + // When !isLocal, we want to check that a GP notification occured for + // every single test case. + trk, err = newGPNotificationTracker() + if err != nil { + t.Fatalf("newGPNotificationTracker: %v\n", err) + } + defer trk.Close() + } caseDomains := domains[:n] - err := mgr.setSplitDNS(resolvers, caseDomains) + err = mgr.setSplitDNS(resolvers, caseDomains) if err != nil { t.Fatalf("setSplitDNS: %v\n", err) } - validateRegistry(t, caseDomains) + validateRegistry(t, regBaseValidate, caseDomains) + ensureNoRulesInSubkey(t, regBaseEnsure) + if !isLocal && !trk.DidRefresh(true) { + t.Fatalf("DidRefresh false, want true\n") + } + } + + for _, n := range cases { + runCase(n) + } + + if isLocal && trk.DidRefresh(false) { + t.Errorf("DidRefresh true, want false\n") } t.Logf("Test case: nil resolver\n") @@ -92,23 +196,92 @@ func TestManagerWindows(t *testing.T) { ensureNoRules(t) } +func createFakeGPKey() error { + keyStr := nrptBaseGP + `\` + testGPRuleID + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %s: %w", keyStr, err) + } + defer key.Close() + if err := key.SetDWordValue("Version", 1); err != nil { + return err + } + if err := key.SetStringsValue("Name", []string{"._setbygp_.example.com"}); err != nil { + return err + } + if err := key.SetStringValue("GenericDNSServers", "1.1.1.1"); err != nil { + return err + } + if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { + return err + } + return nil +} + +func deleteFakeGPKey(t *testing.T) { + keyName := nrptBaseGP + `\` + testGPRuleID + if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist { + t.Fatalf("Error deleting NRPT rule key %q: %v\n", keyName, err) + } + + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil { + t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err) + } + + if !isEmpty { + return + } + + if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil { + t.Fatalf("Deleting DnsPolicyKey Subkey: %v", err) + } +} + func ensureNoRules(t *testing.T) { ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil) if ruleIDs != nil { t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs) } - legacyKeyPath := nrptBase + nrptSingleRuleID - key, err := registry.OpenKey(registry.LOCAL_MACHINE, legacyKeyPath, registry.READ) + for _, base := range []string{nrptBaseLocal, nrptBaseGP} { + ensureNoSingleRule(t, base) + } +} + +func ensureNoRulesInSubkey(t *testing.T, base string) { + ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil) + if ruleIDs == nil { + for _, base := range []string{nrptBaseLocal, nrptBaseGP} { + ensureNoSingleRule(t, base) + } + return + } + + for _, ruleID := range ruleIDs { + keyName := base + `\` + ruleID + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ) + if err == nil { + key.Close() + } + if err != registry.ErrNotExist { + t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist) + } + } +} + +func ensureNoSingleRule(t *testing.T, base string) { + singleKeyPath := base + `\` + nrptSingleRuleID + key, err := registry.OpenKey(registry.LOCAL_MACHINE, singleKeyPath, registry.READ) if err == nil { key.Close() } if err != registry.ErrNotExist { - t.Errorf("%s: %q, want %q\n", legacyKeyPath, err, registry.ErrNotExist) + t.Fatalf("%s: %q, want %q\n", singleKeyPath, err, registry.ErrNotExist) } } -func validateRegistry(t *testing.T, domains []dnsname.FQDN) { +func validateRegistry(t *testing.T, nrptBase string, domains []dnsname.FQDN) { q := len(domains) / nrptMaxDomainsPerRule r := len(domains) % nrptMaxDomainsPerRule numRules := q @@ -124,9 +297,9 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) { } for i, ruleID := range ruleIDs { - savedDomains, err := getSavedDomainsForRule(ruleID) + savedDomains, err := getSavedDomainsForRule(nrptBase, ruleID) if err != nil { - t.Fatalf("getSavedDomainsForRule(%q): %v\n", ruleID, err) + t.Fatalf("getSavedDomainsForRule(%q, %q): %v\n", nrptBase, ruleID, err) } start := i * nrptMaxDomainsPerRule @@ -148,8 +321,8 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) { } } -func getSavedDomainsForRule(ruleID string) ([]string, error) { - keyPath := nrptBase + ruleID +func getSavedDomainsForRule(base, ruleID string) ([]string, error) { + keyPath := base + `\` + ruleID key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ) if err != nil { return nil, err @@ -158,3 +331,56 @@ func getSavedDomainsForRule(ruleID string) ([]string, error) { result, _, err := key.GetStringsValue("Name") return result, err } + +// gpNotificationTracker registers with the Windows policy engine and receives +// notifications when policy refreshes occur. +type gpNotificationTracker struct { + event windows.Handle +} + +func newGPNotificationTracker() (*gpNotificationTracker, error) { + var err error + evt, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + windows.CloseHandle(evt) + } + }() + + ok, _, e := procRegisterGPNotification.Call( + uintptr(evt), + uintptr(1), // We want computer policy changes, not user policy changes. + ) + if ok == 0 { + err = e + return nil, err + } + + return &gpNotificationTracker{evt}, nil +} + +func (trk *gpNotificationTracker) DidRefresh(isExpected bool) bool { + // If we're not expecting a refresh event, then we need to use a timeout. + timeout := uint32(1000) // 1 second (in milliseconds) + if isExpected { + // Otherwise, since it is imperative that we see an event, we wait infinitely. + timeout = windows.INFINITE + } + + return trk.DidRefreshTimeout(timeout) +} + +func (trk *gpNotificationTracker) DidRefreshTimeout(timeout uint32) bool { + waitCode, _ := windows.WaitForSingleObject(trk.event, timeout) + return waitCode == windows.WAIT_OBJECT_0 +} + +func (trk *gpNotificationTracker) Close() error { + procUnregisterGPNotification.Call(uintptr(trk.event)) + windows.CloseHandle(trk.event) + trk.event = 0 + return nil +} diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go new file mode 100644 index 000000000..ceef38107 --- /dev/null +++ b/net/dns/nrpt_windows.go @@ -0,0 +1,331 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +import ( + "fmt" + "strings" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/winutil" +) + +const ( + dnsBaseGP = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient` + nrptBaseLocal = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig` + nrptBaseGP = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig` + + nrptOverrideDNS = 0x8 // bitmask value for "use the provided override DNS resolvers" + + // Apparently NRPT rules cannot handle > 50 domains. + nrptMaxDomainsPerRule = 50 + + // This is the legacy rule ID that previous versions used when we supported + // only a single rule. Now that we support multiple rules are required, we + // generate their GUIDs and store them under the Tailscale registry key. + nrptSingleRuleID = `{5abe529b-675b-4486-8459-25a634dacc23}` + + // This is the name of the registry value we use to save Rule IDs under + // the Tailscale registry key. + nrptRuleIDValueName = `NRPTRuleIDs` +) + +var ( + libUserenv = windows.NewLazySystemDLL("userenv.dll") + procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx") +) + +const _RP_FORCE = 1 // Flag for RefreshPolicyEx + +// nrptRuleDatabase ensapsulates access to the Windows Name Resolution Policy +// Table (NRPT). +type nrptRuleDatabase struct { + logf logger.Logf + ruleIDs []string + writeAsGP bool + isGPDirty bool +} + +func newNRPTRuleDatabase(logf logger.Logf) *nrptRuleDatabase { + ret := &nrptRuleDatabase{logf: logf} + ret.loadRuleSubkeyNames() + ret.initWriteAsGP() + logf("nrptRuleDatabase using group policy: %v\n", ret.writeAsGP) + // Best-effort: if our NRPT rule exists, try to delete it. Unlike + // per-interface configuration, NRPT rules survive the unclean + // termination of the Tailscale process, and depending on the + // rule, it may prevent us from reaching login.tailscale.com to + // boot up. The bootstrap resolver logic will save us, but it + // slows down start-up a bunch. + ret.DelAllRuleKeys() + return ret +} + +func (db *nrptRuleDatabase) loadRuleSubkeyNames() { + result := winutil.GetRegStrings(nrptRuleIDValueName, nil) + if result == nil { + // Use the legacy rule ID if none are specified in our registry key + result = []string{nrptSingleRuleID} + } + db.ruleIDs = result +} + +// initWriteAsGP determines which registry path should be used for writing +// NRPT rules. If there are rules in the GP path that don't belong to us, then +// we should use the GP path. +func (db *nrptRuleDatabase) initWriteAsGP() { + var err error + defer func() { + if err != nil { + db.writeAsGP = false + } + }() + + dnsKey, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ) + if err != nil { + db.logf("Failed to open key %q with error: %v\n", dnsBaseGP, err) + return + } + defer dnsKey.Close() + + ki, err := dnsKey.Stat() + if err != nil { + db.logf("Failed to stat key %q with error: %v\n", dnsBaseGP, err) + return + } + + // If the dnsKey contains any values, then we need to use the GP key. + if ki.ValueCount > 0 { + db.writeAsGP = true + return + } + + if ki.SubKeyCount == 0 { + // If dnsKey contains no values and no subkeys, then we definitely don't + // need to use the GP key. + db.writeAsGP = false + return + } + + // Get a list of all the NRPT rules under the GP subkey. + nrptKey, err := registry.OpenKey(registry.LOCAL_MACHINE, nrptBaseGP, registry.READ) + if err != nil { + db.logf("Failed to open key %q with error: %v\n", nrptBaseGP, err) + return + } + defer nrptKey.Close() + + gpSubkeyNames, err := nrptKey.ReadSubKeyNames(0) + if err != nil { + db.logf("Failed to list subkeys under %q with error: %v\n", nrptBaseGP, err) + return + } + + // Add *all* rules from the GP subkey into a set. + gpSubkeyMap := make(map[string]struct{}, len(gpSubkeyNames)) + for _, gpSubkey := range gpSubkeyNames { + gpSubkeyMap[strings.ToUpper(gpSubkey)] = struct{}{} + } + + // Remove *our* rules from the set. + for _, ourRuleID := range db.ruleIDs { + delete(gpSubkeyMap, strings.ToUpper(ourRuleID)) + } + + // Any leftover rules do not belong to us. When group policy is being used + // by something else, we must also use the GP path. + db.writeAsGP = len(gpSubkeyMap) > 0 +} + +// DelAllRuleKeys removes any and all NRPT rules that are owned by Tailscale. +func (db *nrptRuleDatabase) DelAllRuleKeys() error { + if err := db.delRuleKeys(db.ruleIDs); err != nil { + return err + } + if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { + db.logf("Error deleting registry value %q: %v", nrptRuleIDValueName, err) + return err + } + db.ruleIDs = nil + return nil +} + +// delRuleKeys removes the NRPT rules specified by nrptRuleIDs from the +// Windows registry. It attempts to remove the rules from both possible registry +// keys: the local key and the group policy key. +func (db *nrptRuleDatabase) delRuleKeys(nrptRuleIDs []string) error { + for _, rid := range nrptRuleIDs { + keyNameLocal := nrptBaseLocal + `\` + rid + if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyNameLocal); err != nil && err != registry.ErrNotExist { + db.logf("Error deleting NRPT rule key %q: %v", keyNameLocal, err) + return err + } + + keyNameGP := nrptBaseGP + `\` + rid + err := registry.DeleteKey(registry.LOCAL_MACHINE, keyNameGP) + if err == nil { + // If this deleted subkey existed under the GP key, we will need to refresh. + db.isGPDirty = true + } else if err != registry.ErrNotExist { + db.logf("Error deleting NRPT rule key %q: %v", keyNameGP, err) + return err + } + } + + if !db.isGPDirty { + return nil + } + + // If we've removed keys from the Group Policy subkey, and the DNSPolicyConfig + // subkey is now empty, we need to remove that subkey. + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil || !isEmpty { + return err + } + + return registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP) +} + +// isPolicyConfigSubkeyEmpty returns true if and only if the nrptBaseGP exists +// and does not contain any values or subkeys. +func isPolicyConfigSubkeyEmpty() (bool, error) { + subKey, err := registry.OpenKey(registry.LOCAL_MACHINE, nrptBaseGP, registry.READ) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + defer subKey.Close() + + ki, err := subKey.Stat() + if err != nil { + return false, err + } + + return (ki.ValueCount == 0 && ki.SubKeyCount == 0), nil +} + +func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsname.FQDN) error { + // NRPT has an undocumented restriction that each rule may only be associated + // with a maximum of 50 domains. If we are setting rules for more domains + // than that, we need to split domains into chunks and write out a rule per chunk. + dq := len(domains) / nrptMaxDomainsPerRule + dr := len(domains) % nrptMaxDomainsPerRule + + domainRulesLen := dq + if dr > 0 { + domainRulesLen++ + } + + db.loadRuleSubkeyNames() + for len(db.ruleIDs) < domainRulesLen { + guid, err := windows.GenerateGUID() + if err != nil { + return err + } + db.ruleIDs = append(db.ruleIDs, guid.String()) + } + + // Remove any surplus rules that are no longer needed. + ruleIDsToRemove := db.ruleIDs[domainRulesLen:] + db.delRuleKeys(ruleIDsToRemove) + + // We need to save the list of rule IDs to our Tailscale registry key so that + // we know which rules are ours during subsequent modifications to NRPT rules. + ruleIDsToWrite := db.ruleIDs[:domainRulesLen] + if len(ruleIDsToWrite) == 0 { + if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil { + return err + } + db.ruleIDs = nil + return nil + } + + if err := winutil.SetRegStrings(nrptRuleIDValueName, ruleIDsToWrite); err != nil { + return err + } + db.ruleIDs = ruleIDsToWrite + + curRuleID := 0 + doms := make([]string, 0, nrptMaxDomainsPerRule) + + for _, domain := range domains { + if len(doms) == nrptMaxDomainsPerRule { + if err := db.writeNRPTRule(db.ruleIDs[curRuleID], servers, doms); err != nil { + return err + } + curRuleID++ + doms = doms[:0] + } + + // NRPT rules must have a leading dot, which is not usual for + // DNS search paths. + doms = append(doms, "."+domain.WithoutTrailingDot()) + } + + if len(doms) > 0 { + if err := db.writeNRPTRule(db.ruleIDs[curRuleID], servers, doms); err != nil { + return err + } + } + + return nil +} + +// Refresh notifies the Windows group policy engine when policies have changed. +func (db *nrptRuleDatabase) Refresh() { + if !db.isGPDirty { + return + } + ok, _, err := procRefreshPolicyEx.Call( + uintptr(1), // Win32 TRUE: Refresh computer policy, not user policy. + uintptr(_RP_FORCE), + ) + if ok == 0 { + db.logf("RefreshPolicyEx failed: %v", err) + return + } + db.isGPDirty = false +} + +func (db *nrptRuleDatabase) writeNRPTRule(ruleID string, servers, doms []string) error { + var nrptBase string + if db.writeAsGP { + nrptBase = nrptBaseGP + } else { + nrptBase = nrptBaseLocal + } + + keyStr := nrptBase + `\` + ruleID + + // CreateKey is actually open-or-create, which suits us fine. + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %s: %w", keyStr, err) + } + defer key.Close() + if err := key.SetDWordValue("Version", 1); err != nil { + return err + } + if err := key.SetStringsValue("Name", doms); err != nil { + return err + } + if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil { + return err + } + if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { + return err + } + + if db.writeAsGP { + db.isGPDirty = true + } + + return nil +}