diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 7bf1432cf..167eae2b6 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -62,7 +62,7 @@ func TestManagerWindowsGP(t *testing.T) { runTest(t, false) } -func TestManagerWindowsGPMove(t *testing.T) { +func TestManagerWindowsGPCopy(t *testing.T) { if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { t.Skipf("test requires running as elevated user on Windows 10+") } @@ -139,10 +139,10 @@ func TestManagerWindowsGPMove(t *testing.T) { t.Fatalf("regWatcher.wait: %v\n", err) } - // 3. Check that local NRPT is empty and GP is populated + // 3. Check that both local NRPT and GP NRPT are populated t.Logf("Validating that group policy NRPT is populated...\n") + validateRegistry(t, nrptBaseLocal, domains) validateRegistry(t, nrptBaseGP, domains) - ensureNoRulesInSubkey(t, nrptBaseLocal) // 4. Delete fake GP key and refresh t.Logf("Deleting fake group policy key and refreshing...\n") @@ -578,25 +578,11 @@ func (trk *gpNotificationTracker) Close() error { } type regKeyWatcher struct { - keyLocal registry.Key - keyGP registry.Key - evtLocal windows.Handle - evtGP windows.Handle + keyGP registry.Key + evtGP windows.Handle } -func newRegKeyWatcher() (*regKeyWatcher, error) { - var err error - - keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ) - if err != nil { - return nil, err - } - defer func() { - if err != nil { - keyLocal.Close() - } - }() - +func newRegKeyWatcher() (result *regKeyWatcher, err error) { // Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be // repeatedly created and destroyed throughout the course of the test. keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ) @@ -609,58 +595,31 @@ func newRegKeyWatcher() (*regKeyWatcher, error) { } }() - evtLocal, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return nil, err - } - defer func() { - if err != nil { - windows.CloseHandle(evtLocal) - } - }() - evtGP, err := windows.CreateEvent(nil, 0, 0, nil) if err != nil { return nil, err } - result := ®KeyWatcher{ - keyLocal: keyLocal, - keyGP: keyGP, - evtLocal: evtLocal, - evtGP: evtGP, - } - - return result, nil + return ®KeyWatcher{ + keyGP: keyGP, + evtGP: evtGP, + }, nil } func (rw *regKeyWatcher) watch() error { // We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+ - err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true, - windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true) - if err != nil { - return err - } - return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true, windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true) } func (rw *regKeyWatcher) wait() error { - handles := []windows.Handle{ - rw.evtLocal, + waitCode, err := windows.WaitForSingleObject( rw.evtGP, - } - - waitCode, err := windows.WaitForMultipleObjects( - handles, - true, // Wait for both events to signal before resuming. 10000, // 10 seconds (as milliseconds) ) - const WAIT_TIMEOUT = 0x102 switch waitCode { - case WAIT_TIMEOUT: + case uint32(windows.WAIT_TIMEOUT): return context.DeadlineExceeded case windows.WAIT_FAILED: return err @@ -670,9 +629,7 @@ func (rw *regKeyWatcher) wait() error { } func (rw *regKeyWatcher) Close() error { - rw.keyLocal.Close() rw.keyGP.Close() - windows.CloseHandle(rw.evtLocal) windows.CloseHandle(rw.evtGP) return nil } diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go index 78a702616..06dbe7806 100644 --- a/net/dns/nrpt_windows.go +++ b/net/dns/nrpt_windows.go @@ -86,12 +86,8 @@ func newNRPTRuleDatabase(logf logger.Logf) *nrptRuleDatabase { } func (db *nrptRuleDatabase) loadRuleSubkeyNames() { - result := winutil.GetRegStrings(nrptRuleIDValueName, nil) - if result == nil { - // Use the legacy rule ID if none are specified in our registry key - result = []string{nrptSingleRuleID} - } - db.ruleIDs = result + // Use the legacy rule ID if none are specified in our registry key + db.ruleIDs = winutil.GetRegStrings(nrptRuleIDValueName, []string{nrptSingleRuleID}) } // detectWriteAsGP determines which registry path should be used for writing @@ -113,41 +109,20 @@ func (db *nrptRuleDatabase) detectWriteAsGP() { db.writeAsGP = writeAsGP db.logf("nrptRuleDatabase using group policy: %v, was %v\n", writeAsGP, prev) // When db.watcher == nil, prev != writeAsGP because we're initializing, not - // because anything has changed. We do not invoke db.movePolicies in that case. + // because anything has changed. We do not invoke + // db.updateGroupPoliciesLocked in that case. if db.watcher != nil && prev != writeAsGP { - db.movePolicies(writeAsGP) + db.updateGroupPoliciesLocked(writeAsGP) } }() - dnsKey, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ) - if err != nil { - db.logf("Failed to open key %q with error: %v\n", dnsBaseGP, err) - return - } - defer dnsKey.Close() - - ki, err := dnsKey.Stat() - if err != nil { - db.logf("Failed to stat key %q with error: %v\n", dnsBaseGP, err) - return - } - - // If the dnsKey contains any values, then we need to use the GP key. - if ki.ValueCount > 0 { - writeAsGP = true - return - } - - if ki.SubKeyCount == 0 { - // If dnsKey contains no values and no subkeys, then we definitely don't - // need to use the GP key. - return - } - // Get a list of all the NRPT rules under the GP subkey. nrptKey, err := registry.OpenKey(registry.LOCAL_MACHINE, nrptBaseGP, registry.READ) if err != nil { - db.logf("Failed to open key %q with error: %v\n", nrptBaseGP, err) + if err != registry.ErrNotExist { + db.logf("Failed to open key %q with error: %v\n", nrptBaseGP, err) + } + // If this subkey does not exist then we definitely don't need to use the GP key. return } defer nrptKey.Close() @@ -253,14 +228,7 @@ func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsn // NRPT has an undocumented restriction that each rule may only be associated // with a maximum of 50 domains. If we are setting rules for more domains // than that, we need to split domains into chunks and write out a rule per chunk. - dq := len(domains) / nrptMaxDomainsPerRule - dr := len(domains) % nrptMaxDomainsPerRule - - domainRulesLen := dq - if dr > 0 { - domainRulesLen++ - } - + domainRulesLen := (len(domains) + nrptMaxDomainsPerRule - 1) / nrptMaxDomainsPerRule db.loadRuleSubkeyNames() for len(db.ruleIDs) < domainRulesLen { @@ -348,28 +316,26 @@ func (db *nrptRuleDatabase) refreshLocked() { } func (db *nrptRuleDatabase) writeNRPTRule(ruleID string, servers, doms []string) error { - var nrptBase string - if db.writeAsGP { - nrptBase = nrptBaseGP - } else { - nrptBase = nrptBaseLocal + subKeys := []string{nrptBaseLocal, nrptBaseGP} + if !db.writeAsGP { + // We don't want to write to the GP key, so chop nrptBaseGP off of subKeys. + subKeys = subKeys[:1] } - keyStr := nrptBase + `\` + ruleID - - // CreateKey is actually open-or-create, which suits us fine. - key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("opening %s: %w", keyStr, err) - } - defer key.Close() + for _, subKeyBase := range subKeys { + subKey := strings.Join([]string{subKeyBase, ruleID}, `\`) + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %q: %w", subKey, err) + } + defer key.Close() - if err := writeNRPTValues(key, strings.Join(servers, "; "), doms); err != nil { - return err + if err := writeNRPTValues(key, strings.Join(servers, "; "), doms); err != nil { + return err + } } db.isGPDirty = db.writeAsGP - return nil } @@ -400,8 +366,6 @@ func writeNRPTValues(key registry.Key, servers string, doms []string) error { } func (db *nrptRuleDatabase) watchForGPChanges() { - db.isGPRefreshPending.Store(false) - watchHandler := func() { // Do not invoke detectWriteAsGP when we ourselves were responsible for // initiating the group policy refresh. @@ -420,34 +384,29 @@ func (db *nrptRuleDatabase) watchForGPChanges() { db.watcher = watcher } -// movePolicies moves each NRPT rule depending on the value of writeAsGP. -// When writeAsGP is true, each NRPT rule is moved from the local NRPT table -// to the group policy NRPT table. When writeAsGP is false, the move is -// executed in the opposite direction. db.mu should already be locked. -func (db *nrptRuleDatabase) movePolicies(writeAsGP bool) { - // Since we're moving either in or out of the group policy NRPT table, we need - // to refresh once this movePolicies is done. +// updateGroupPoliciesLocked updates the NRPT group policy table depending on +// the value of writeAsGP. When writeAsGP is true, each NRPT rule is copied from +// the local NRPT table to the group policy NRPT table. When writeAsGP is false, +// we remove any Tailscale NRPT rules from the group policy table and, if no +// non-Tailscale rules remain, we also delete the entire DnsPolicyConfig subkey. +// db.mu must already be locked. +func (db *nrptRuleDatabase) updateGroupPoliciesLocked(writeAsGP bool) { + // Since we're updating the group policy NRPT table, we need + // to refresh once this updateGroupPoliciesLocked is done. defer db.refreshLocked() - var fromBase string - var toBase string - if writeAsGP { - fromBase = nrptBaseLocal - toBase = nrptBaseGP - } else { - fromBase = nrptBaseGP - toBase = nrptBaseLocal - } - fromBase += `\` - toBase += `\` - for _, id := range db.ruleIDs { - fromStr := fromBase + id - toStr := toBase + id - - if err := executeMove(fromStr, toStr); err != nil { - db.logf("movePolicies: executeMove(\"%s\", \"%s\") failed with error %v", fromStr, toStr, err) - return + if writeAsGP { + if err := copyNRPTRule(id); err != nil { + db.logf("updateGroupPoliciesLocked: copyNRPTRule(%q) failed with error %v", id, err) + return + } + } else { + subKeyFrom := strings.Join([]string{nrptBaseGP, id}, `\`) + if err := registry.DeleteKey(registry.LOCAL_MACHINE, subKeyFrom); err != nil && err != registry.ErrNotExist { + db.logf("updateGroupPoliciesLocked: DeleteKey for rule %q failed with error %v", id, err) + return + } } db.isGPDirty = true @@ -457,55 +416,49 @@ func (db *nrptRuleDatabase) movePolicies(writeAsGP bool) { return } - // Now that we have moved our rules out of the group policy subkey, it should + // Now that we have removed our rules from group policy subkey, it should // now be empty. Let's verify that. isEmpty, err := isPolicyConfigSubkeyEmpty() if err != nil { - db.logf("movePolicies: isPolicyConfigSubkeyEmpty error %v", err) + db.logf("updateGroupPoliciesLocked: isPolicyConfigSubkeyEmpty error %v", err) return } if !isEmpty { - db.logf("movePolicies: policy config subkey should be empty, but isn't!") + db.logf("updateGroupPoliciesLocked: policy config subkey should be empty, but isn't!") return } // Delete the subkey itself. Group policy will continue to override local // settings unless we do so. if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil { - db.logf("movePolicies DeleteKey error %v", err) + db.logf("updateGroupPoliciesLocked DeleteKey error %v", err) } db.isGPDirty = true } -func executeMove(subKeyFrom, subKeyTo string) error { - err := func() error { - // Move the NRPT registry values from subKeyFrom to subKeyTo. - fromKey, err := registry.OpenKey(registry.LOCAL_MACHINE, subKeyFrom, registry.QUERY_VALUE) - if err != nil { - return err - } - defer fromKey.Close() +func copyNRPTRule(ruleID string) error { + subKeyFrom := strings.Join([]string{nrptBaseLocal, ruleID}, `\`) + subKeyTo := strings.Join([]string{nrptBaseGP, ruleID}, `\`) - toKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKeyTo, registry.WRITE) - if err != nil { - return err - } - defer toKey.Close() + fromKey, err := registry.OpenKey(registry.LOCAL_MACHINE, subKeyFrom, registry.QUERY_VALUE) + if err != nil { + return err + } + defer fromKey.Close() - servers, doms, err := readNRPTValues(fromKey) - if err != nil { - return err - } + toKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKeyTo, registry.WRITE) + if err != nil { + return err + } + defer toKey.Close() - return writeNRPTValues(toKey, servers, doms) - }() + servers, doms, err := readNRPTValues(fromKey) if err != nil { return err } - // This is a move operation, so we must delete subKeyFrom. - return registry.DeleteKey(registry.LOCAL_MACHINE, subKeyFrom) + return writeNRPTValues(toKey, servers, doms) } func (db *nrptRuleDatabase) Close() error {