net/dns: add Windows group policy notifications to the NRPT rule manager

As discussed in previous PRs, we can register for notifications when group
policies are updated and act accordingly.

This patch changes nrptRuleDatabase to receive notifications that group policy
has changed and automatically move our NRPT rules between the local and
group policy subkeys as needed.

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
pull/5155/head
Aaron Klotz 2 years ago
parent f17873e0f4
commit 1cae618b03

@ -370,8 +370,8 @@ type dnsTCPSession struct {
conn net.Conn
srcAddr netaddr.IPPort
readClosing chan struct{}
responses chan []byte // DNS replies pending writing
readClosing chan struct{}
responses chan []byte // DNS replies pending writing
ctx context.Context
closeCtx context.CancelFunc
@ -457,11 +457,11 @@ func (s *dnsTCPSession) handleReads() {
// servicing DNS requests sent down it.
func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) {
s := dnsTCPSession{
m: m,
conn: conn,
srcAddr: srcAddr,
responses: make(chan []byte),
readClosing: make(chan struct{}),
m: m,
conn: conn,
srcAddr: srcAddr,
responses: make(chan []byte),
readClosing: make(chan struct{}),
}
s.ctx, s.closeCtx = context.WithCancel(m.ctx)
go s.handleReads()

@ -297,7 +297,11 @@ func (m windowsManager) SupportsSplitDNS() bool {
}
func (m windowsManager) Close() error {
return m.SetDNS(OSConfig{})
err := m.SetDNS(OSConfig{})
if m.nrptDB != nil {
m.nrptDB.Close()
}
return err
}
// disableDynamicUpdates sets the appropriate registry values to prevent the

@ -5,6 +5,7 @@
package dns
import (
"context"
"fmt"
"math/rand"
"strings"
@ -20,11 +21,6 @@ import (
const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
var (
procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
)
func TestManagerWindowsLocal(t *testing.T) {
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user on Windows 10+")
@ -53,6 +49,121 @@ func TestManagerWindowsGP(t *testing.T) {
runTest(t, false)
}
func TestManagerWindowsGPMove(t *testing.T) {
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user on Windows 10+")
}
checkGPNotificationsWork(t)
logf := func(format string, args ...any) {
t.Logf(format, args...)
}
fakeInterface, err := windows.GenerateGUID()
if err != nil {
t.Fatalf("windows.GenerateGUID: %v\n", err)
}
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
if err != nil {
t.Fatalf("createFakeInterfaceKey: %v\n", err)
}
defer delIfKey()
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
if err != nil {
t.Fatalf("NewOSConfigurator: %v\n", err)
}
mgr := cfg.(windowsManager)
defer mgr.Close()
usingGP := mgr.nrptDB.writeAsGP
if usingGP {
t.Fatalf("usingGP %v, want %v\n", usingGP, false)
}
regWatcher, err := newRegKeyWatcher()
if err != nil {
t.Fatalf("newRegKeyWatcher error %v\n", err)
}
// Upon initialization of cfg, we should not have any NRPT rules
ensureNoRules(t)
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
domains := genRandomSubdomains(t, 1)
// 1. Populate local NRPT
err = mgr.setSplitDNS(resolvers, domains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
t.Logf("Validating that local NRPT is populated...\n")
validateRegistry(t, nrptBaseLocal, domains)
ensureNoRulesInSubkey(t, nrptBaseGP)
// 2. Create fake GP key and refresh
t.Logf("Creating fake group policy key and refreshing...\n")
err = createFakeGPKey()
if err != nil {
t.Fatalf("createFakeGPKey: %v\n", err)
}
err = regWatcher.watch()
if err != nil {
t.Fatalf("regWatcher.watch: %v\n", err)
}
err = testDoRefresh()
if err != nil {
t.Fatalf("testDoRefresh: %v\n", err)
}
err = regWatcher.wait()
if err != nil {
t.Fatalf("regWatcher.wait: %v\n", err)
}
// 3. Check that local NRPT is empty and GP is populated
t.Logf("Validating that group policy NRPT is populated...\n")
validateRegistry(t, nrptBaseGP, domains)
ensureNoRulesInSubkey(t, nrptBaseLocal)
// 4. Delete fake GP key and refresh
t.Logf("Deleting fake group policy key and refreshing...\n")
deleteFakeGPKey(t)
err = regWatcher.watch()
if err != nil {
t.Fatalf("regWatcher.watch: %v\n", err)
}
err = testDoRefresh()
if err != nil {
t.Fatalf("testDoRefresh: %v\n", err)
}
err = regWatcher.wait()
if err != nil {
t.Fatalf("regWatcher.wait: %v\n", err)
}
// 5. Check that local NRPT is populated and GP is empty
t.Logf("Validating that local NRPT is populated...\n")
validateRegistry(t, nrptBaseLocal, domains)
ensureNoRulesInSubkey(t, nrptBaseGP)
// 6. Cleanup
t.Logf("Cleaning up...\n")
err = mgr.setSplitDNS(nil, domains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
ensureNoRules(t)
}
func checkGPNotificationsWork(t *testing.T) {
// Test to ensure that RegisterGPNotification work on this machine,
// otherwise this test will fail.
@ -83,11 +194,18 @@ func runTest(t *testing.T, isLocal bool) {
t.Fatalf("windows.GenerateGUID: %v\n", err)
}
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
if err != nil {
t.Fatalf("createFakeInterfaceKey: %v\n", err)
}
defer delIfKey()
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
if err != nil {
t.Fatalf("NewOSConfigurator: %v\n", err)
}
mgr := cfg.(windowsManager)
defer mgr.Close()
usingGP := mgr.nrptDB.writeAsGP
if isLocal == usingGP {
@ -99,25 +217,7 @@ func runTest(t *testing.T, isLocal bool) {
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const charset = "abcdefghijklmnopqrstuvwxyz"
// Just generate a bunch of random subdomains
for len(domains) < cap(domains) {
l := r.Intn(19) + 1
b := make([]byte, l)
for i, _ := range b {
b[i] = charset[r.Intn(len(charset))]
}
d := string(b) + ".example.com"
fqdn, err := dnsname.ToFQDN(d)
if err != nil {
t.Fatalf("dnsname.ToFQDN: %v\n", err)
}
domains = append(domains, fqdn)
}
domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1)
cases := []int{
1,
@ -238,6 +338,32 @@ func deleteFakeGPKey(t *testing.T) {
}
}
func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
basePaths := []string{ipv4RegBase, ipv6RegBase}
keyPaths := make([]string, 0, len(basePaths))
for _, basePath := range basePaths {
keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid)
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil {
return nil, err
}
key.Close()
keyPaths = append(keyPaths, keyPath)
}
result := func() {
for _, keyPath := range keyPaths {
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil {
t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err)
}
}
}
return result, nil
}
func ensureNoRules(t *testing.T) {
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
if ruleIDs != nil {
@ -263,11 +389,29 @@ func ensureNoRulesInSubkey(t *testing.T, base string) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
if err == nil {
key.Close()
}
if err != registry.ErrNotExist {
} else if err != registry.ErrNotExist {
t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
}
}
if base == nrptBaseGP {
// When dealing with the group policy subkey, we want the base key to
// also be absent.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ)
if err == nil {
key.Close()
isEmpty, err := isPolicyConfigSubkeyEmpty()
if err != nil {
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
}
if isEmpty {
t.Errorf("Unexpectedly found group policy key\n")
}
} else if err != registry.ErrNotExist {
t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist)
}
}
}
func ensureNoSingleRule(t *testing.T, base string) {
@ -332,6 +476,40 @@ func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
return result, err
}
func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN {
domains := make([]dnsname.FQDN, 0, n)
seed := time.Now().UnixNano()
t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed)
r := rand.New(rand.NewSource(seed))
const charset = "abcdefghijklmnopqrstuvwxyz"
for len(domains) < cap(domains) {
l := r.Intn(19) + 1
b := make([]byte, l)
for i, _ := range b {
b[i] = charset[r.Intn(len(charset))]
}
d := string(b) + ".example.com"
fqdn, err := dnsname.ToFQDN(d)
if err != nil {
t.Fatalf("dnsname.ToFQDN: %v\n", err)
}
domains = append(domains, fqdn)
}
return domains
}
func testDoRefresh() (err error) {
r, _, e := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
if r == 0 {
err = e
}
return err
}
// gpNotificationTracker registers with the Windows policy engine and receives
// notifications when policy refreshes occur.
type gpNotificationTracker struct {
@ -384,3 +562,103 @@ func (trk *gpNotificationTracker) Close() error {
trk.event = 0
return nil
}
type regKeyWatcher struct {
keyLocal registry.Key
keyGP registry.Key
evtLocal windows.Handle
evtGP windows.Handle
}
func newRegKeyWatcher() (*regKeyWatcher, error) {
var err error
keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
keyLocal.Close()
}
}()
// Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be
// repeatedly created and destroyed throughout the course of the test.
keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
keyGP.Close()
}
}()
evtLocal, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtLocal)
}
}()
evtGP, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
result := &regKeyWatcher{
keyLocal: keyLocal,
keyGP: keyGP,
evtLocal: evtLocal,
evtGP: evtGP,
}
return result, nil
}
func (rw *regKeyWatcher) watch() error {
// We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+
err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true,
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true)
if err != nil {
return err
}
return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true,
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true)
}
func (rw *regKeyWatcher) wait() error {
handles := []windows.Handle{
rw.evtLocal,
rw.evtGP,
}
waitCode, err := windows.WaitForMultipleObjects(
handles,
true, // Wait for both events to signal before resuming.
10000, // 10 seconds (as milliseconds)
)
const WAIT_TIMEOUT = 0x102
switch waitCode {
case WAIT_TIMEOUT:
return context.DeadlineExceeded
case windows.WAIT_FAILED:
return err
default:
return nil
}
}
func (rw *regKeyWatcher) Close() error {
rw.keyLocal.Close()
rw.keyGP.Close()
windows.CloseHandle(rw.evtLocal)
windows.CloseHandle(rw.evtGP)
return nil
}

@ -7,6 +7,8 @@ package dns
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
@ -33,11 +35,25 @@ const (
// This is the name of the registry value we use to save Rule IDs under
// the Tailscale registry key.
nrptRuleIDValueName = `NRPTRuleIDs`
// This is the name of the registry value the NRPT uses for storing a rule's version number.
nrptRuleVersionName = `Version`
// This is the name of the registry value the NRPT uses for storing a rule's list of domains.
nrptRuleDomsName = `Name`
// This is the name of the registry value the NRPT uses for storing a rule's list of DNS servers.
nrptRuleServersName = `GenericDNSServers`
// This is the name of the registry value the NRPT uses for storing a rule's flags.
nrptRuleFlagsName = `ConfigOptions`
)
var (
libUserenv = windows.NewLazySystemDLL("userenv.dll")
procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx")
libUserenv = windows.NewLazySystemDLL("userenv.dll")
procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx")
procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
)
const _RP_FORCE = 1 // Flag for RefreshPolicyEx
@ -45,17 +61,20 @@ const _RP_FORCE = 1 // Flag for RefreshPolicyEx
// nrptRuleDatabase ensapsulates access to the Windows Name Resolution Policy
// Table (NRPT).
type nrptRuleDatabase struct {
logf logger.Logf
ruleIDs []string
writeAsGP bool
isGPDirty bool
logf logger.Logf
watcher *gpNotificationWatcher
isGPRefreshPending atomic.Value // of bool
mu sync.Mutex // protects the fields below
ruleIDs []string
isGPDirty bool
writeAsGP bool
}
func newNRPTRuleDatabase(logf logger.Logf) *nrptRuleDatabase {
ret := &nrptRuleDatabase{logf: logf}
ret.loadRuleSubkeyNames()
ret.initWriteAsGP()
logf("nrptRuleDatabase using group policy: %v\n", ret.writeAsGP)
ret.detectWriteAsGP()
ret.watchForGPChanges()
// Best-effort: if our NRPT rule exists, try to delete it. Unlike
// per-interface configuration, NRPT rules survive the unclean
// termination of the Tailscale process, and depending on the
@ -75,14 +94,28 @@ func (db *nrptRuleDatabase) loadRuleSubkeyNames() {
db.ruleIDs = result
}
// initWriteAsGP determines which registry path should be used for writing
// detectWriteAsGP determines which registry path should be used for writing
// NRPT rules. If there are rules in the GP path that don't belong to us, then
// we should use the GP path.
func (db *nrptRuleDatabase) initWriteAsGP() {
// we should use the GP path. When detectWriteAsGP determines that the desired
// path has changed, it moves the NRPT policies as appropriate.
func (db *nrptRuleDatabase) detectWriteAsGP() {
db.mu.Lock()
defer db.mu.Unlock()
writeAsGP := false
var err error
defer func() {
if err != nil {
db.writeAsGP = false
return
}
prev := db.writeAsGP
db.writeAsGP = writeAsGP
db.logf("nrptRuleDatabase using group policy: %v, was %v\n", writeAsGP, prev)
// When db.watcher == nil, prev != writeAsGP because we're initializing, not
// because anything has changed. We do not invoke db.movePolicies in that case.
if db.watcher != nil && prev != writeAsGP {
db.movePolicies(writeAsGP)
}
}()
@ -101,14 +134,13 @@ func (db *nrptRuleDatabase) initWriteAsGP() {
// If the dnsKey contains any values, then we need to use the GP key.
if ki.ValueCount > 0 {
db.writeAsGP = true
writeAsGP = true
return
}
if ki.SubKeyCount == 0 {
// If dnsKey contains no values and no subkeys, then we definitely don't
// need to use the GP key.
db.writeAsGP = false
return
}
@ -139,11 +171,14 @@ func (db *nrptRuleDatabase) initWriteAsGP() {
// Any leftover rules do not belong to us. When group policy is being used
// by something else, we must also use the GP path.
db.writeAsGP = len(gpSubkeyMap) > 0
writeAsGP = len(gpSubkeyMap) > 0
}
// DelAllRuleKeys removes any and all NRPT rules that are owned by Tailscale.
func (db *nrptRuleDatabase) DelAllRuleKeys() error {
db.mu.Lock()
defer db.mu.Unlock()
if err := db.delRuleKeys(db.ruleIDs); err != nil {
return err
}
@ -212,6 +247,9 @@ func isPolicyConfigSubkeyEmpty() (bool, error) {
}
func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsname.FQDN) error {
db.mu.Lock()
defer db.mu.Unlock()
// NRPT has an undocumented restriction that each rule may only be associated
// with a maximum of 50 domains. If we are setting rules for more domains
// than that, we need to split domains into chunks and write out a rule per chunk.
@ -224,6 +262,7 @@ func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsn
}
db.loadRuleSubkeyNames()
for len(db.ruleIDs) < domainRulesLen {
guid, err := windows.GenerateGUID()
if err != nil {
@ -280,9 +319,22 @@ func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsn
// Refresh notifies the Windows group policy engine when policies have changed.
func (db *nrptRuleDatabase) Refresh() {
db.mu.Lock()
defer db.mu.Unlock()
db.refreshLocked()
}
func (db *nrptRuleDatabase) refreshLocked() {
if !db.isGPDirty {
return
}
// Record that we are about to initiate a refresh.
// (*nrptRuleDatabase).watchForGPChanges() checks this value to avoid false
// positives.
db.isGPRefreshPending.Store(true)
ok, _, err := procRefreshPolicyEx.Call(
uintptr(1), // Win32 TRUE: Refresh computer policy, not user policy.
uintptr(_RP_FORCE),
@ -291,6 +343,7 @@ func (db *nrptRuleDatabase) Refresh() {
db.logf("RefreshPolicyEx failed: %v", err)
return
}
db.isGPDirty = false
}
@ -310,22 +363,256 @@ func (db *nrptRuleDatabase) writeNRPTRule(ruleID string, servers, doms []string)
return fmt.Errorf("opening %s: %w", keyStr, err)
}
defer key.Close()
if err := key.SetDWordValue("Version", 1); err != nil {
if err := writeNRPTValues(key, strings.Join(servers, "; "), doms); err != nil {
return err
}
if err := key.SetStringsValue("Name", doms); err != nil {
db.isGPDirty = db.writeAsGP
return nil
}
func readNRPTValues(key registry.Key) (servers string, doms []string, err error) {
doms, _, err = key.GetStringsValue(nrptRuleDomsName)
if err != nil {
return servers, doms, err
}
servers, _, err = key.GetStringValue(nrptRuleServersName)
return servers, doms, err
}
func writeNRPTValues(key registry.Key, servers string, doms []string) error {
if err := key.SetDWordValue(nrptRuleVersionName, 1); err != nil {
return err
}
if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil {
if err := key.SetStringsValue(nrptRuleDomsName, doms); err != nil {
return err
}
if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil {
if err := key.SetStringValue(nrptRuleServersName, servers); err != nil {
return err
}
if db.writeAsGP {
return key.SetDWordValue(nrptRuleFlagsName, nrptOverrideDNS)
}
func (db *nrptRuleDatabase) watchForGPChanges() {
db.isGPRefreshPending.Store(false)
watchHandler := func() {
// Do not invoke detectWriteAsGP when we ourselves were responsible for
// initiating the group policy refresh.
if db.isGPRefreshPending.CompareAndSwap(true, false) {
return
}
db.logf("Computer group policies refreshed, reconfiguring NRPT rule database.")
db.detectWriteAsGP()
}
watcher, err := newGPNotificationWatcher(watchHandler)
if err != nil {
return
}
db.watcher = watcher
}
// movePolicies moves each NRPT rule depending on the value of writeAsGP.
// When writeAsGP is true, each NRPT rule is moved from the local NRPT table
// to the group policy NRPT table. When writeAsGP is false, the move is
// executed in the opposite direction. db.mu should already be locked.
func (db *nrptRuleDatabase) movePolicies(writeAsGP bool) {
// Since we're moving either in or out of the group policy NRPT table, we need
// to refresh once this movePolicies is done.
defer db.refreshLocked()
var fromBase string
var toBase string
if writeAsGP {
fromBase = nrptBaseLocal
toBase = nrptBaseGP
} else {
fromBase = nrptBaseGP
toBase = nrptBaseLocal
}
fromBase += `\`
toBase += `\`
for _, id := range db.ruleIDs {
fromStr := fromBase + id
toStr := toBase + id
if err := executeMove(fromStr, toStr); err != nil {
db.logf("movePolicies: executeMove(\"%s\", \"%s\") failed with error %v", fromStr, toStr, err)
return
}
db.isGPDirty = true
}
if writeAsGP {
return
}
// Now that we have moved our rules out of the group policy subkey, it should
// now be empty. Let's verify that.
isEmpty, err := isPolicyConfigSubkeyEmpty()
if err != nil {
db.logf("movePolicies: isPolicyConfigSubkeyEmpty error %v", err)
return
}
if !isEmpty {
db.logf("movePolicies: policy config subkey should be empty, but isn't!")
return
}
// Delete the subkey itself. Group policy will continue to override local
// settings unless we do so.
if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil {
db.logf("movePolicies DeleteKey error %v", err)
}
db.isGPDirty = true
}
func executeMove(subKeyFrom, subKeyTo string) error {
err := func() error {
// Move the NRPT registry values from subKeyFrom to subKeyTo.
fromKey, err := registry.OpenKey(registry.LOCAL_MACHINE, subKeyFrom, registry.QUERY_VALUE)
if err != nil {
return err
}
defer fromKey.Close()
toKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKeyTo, registry.WRITE)
if err != nil {
return err
}
defer toKey.Close()
servers, doms, err := readNRPTValues(fromKey)
if err != nil {
return err
}
return writeNRPTValues(toKey, servers, doms)
}()
if err != nil {
return err
}
// This is a move operation, so we must delete subKeyFrom.
return registry.DeleteKey(registry.LOCAL_MACHINE, subKeyFrom)
}
func (db *nrptRuleDatabase) Close() error {
if db.watcher == nil {
return nil
}
err := db.watcher.Close()
db.watcher = nil
return err
}
type gpNotificationWatcher struct {
gpWaitEvents [2]windows.Handle
handler func()
done chan struct{}
}
// newGPNotificationWatcher creates an instance of gpNotificationWatcher that
// invokes handler every time Windows notifies it of a group policy change.
func newGPNotificationWatcher(handler func()) (*gpNotificationWatcher, error) {
var err error
// evtDone is signaled by (*gpNotificationWatcher).Close() to indicate that
// the doWatch goroutine should exit.
evtDone, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtDone)
}
}()
// evtChanged is registered with the Windows policy engine to become
// signalled any time group policy has been refreshed.
evtChanged, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtChanged)
}
}()
// Tell Windows to signal evtChanged whenever group policies are refreshed.
ok, _, e := procRegisterGPNotification.Call(
uintptr(evtChanged),
uintptr(1), // Win32 TRUE: We want to monitor computer policy changes, not user policy changes.
)
if ok == 0 {
err = e
return nil, err
}
result := &gpNotificationWatcher{
// Ordering of the event handles in gpWaitEvents is important:
// When calling windows.WaitForMultipleObjects and multiple objects are
// signalled simultaneously, it always returns the wait code for the
// lowest-indexed handle in its input array. evtDone is higher priority for
// us than evtChanged, so the former must be placed into the array ahead of
// the latter.
gpWaitEvents: [2]windows.Handle{
evtDone,
evtChanged,
},
handler: handler,
done: make(chan struct{}),
}
go result.doWatch()
return result, nil
}
func (w *gpNotificationWatcher) doWatch() {
// The wait code corresponding to the event that is signalled when a group
// policy change occurs.
const expectedWaitCode = windows.WAIT_OBJECT_0 + 1
for {
if waitCode, _ := windows.WaitForMultipleObjects(w.gpWaitEvents[:], false, windows.INFINITE); waitCode != expectedWaitCode {
break
}
w.handler()
}
close(w.done)
}
func (w *gpNotificationWatcher) Close() error {
// Notify doWatch that we're done and it should exit.
if err := windows.SetEvent(w.gpWaitEvents[0]); err != nil {
return err
}
procUnregisterGPNotification.Call(uintptr(w.gpWaitEvents[1]))
// Wait for doWatch to complete.
<-w.done
// Now we may safely clean up all the things.
for i, evt := range w.gpWaitEvents {
windows.CloseHandle(evt)
w.gpWaitEvents[i] = 0
}
w.handler = nil
return nil
}

Loading…
Cancel
Save