/* SPDX-License-Identifier: MIT * * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. */ package router import ( "bytes" "encoding/binary" "errors" "fmt" "log" "net" "sort" "time" "unsafe" ole "github.com/go-ole/go-ole" winipcfg "github.com/tailscale/winipcfg-go" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/wgcfg" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" "tailscale.com/wgengine/winnet" ) const ( sockoptIP_UNICAST_IF = 31 sockoptIPV6_UNICAST_IF = 31 ) func htonl(val uint32) uint32 { bytes := make([]byte, 4) binary.BigEndian.PutUint32(bytes, val) return *(*uint32)(unsafe.Pointer(&bytes[0])) } func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error { routes, err := winipcfg.GetRoutes(family) if err != nil { return err } lowestMetric := ^uint32(0) index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. for _, route := range routes { if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid { continue } if route.Metric < lowestMetric { lowestMetric = route.Metric index = route.InterfaceIndex luid = route.InterfaceLuid } } if luid == *lastLuid { return nil } *lastLuid = luid if false { // TODO(apenwarr): doesn't work with magic socket yet. if family == winipcfg.AF_INET { return device.BindSocketToInterface4(index, false) } else if family == winipcfg.AF_INET6 { return device.BindSocketToInterface6(index, false) } } else { log.Printf("WARNING: skipping windows socket binding.\n") } return nil } func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { guid := tun.GUID() ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid) lastLuid4 := uint64(0) lastLuid6 := uint64(0) lastMtu := uint32(0) if err != nil { return nil, err } doIt := func() error { err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4) if err != nil { return err } err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6) if err != nil { return err } if !autoMTU { return nil } mtu := uint32(0) if lastLuid4 != 0 { iface, err := winipcfg.InterfaceFromLUID(lastLuid4) if err != nil { return err } if iface.Mtu > 0 { mtu = iface.Mtu } } if lastLuid6 != 0 { iface, err := winipcfg.InterfaceFromLUID(lastLuid6) if err != nil { return err } if iface.Mtu > 0 && iface.Mtu < mtu { mtu = iface.Mtu } } if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) { iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET) if err != nil { return err } iface.NlMtu = mtu - 80 if iface.NlMtu < 576 { iface.NlMtu = 576 } err = iface.Set() if err != nil { return err } tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6) if err != nil { return err } iface.NlMtu = mtu - 80 if iface.NlMtu < 1280 { iface.NlMtu = 1280 } err = iface.Set() if err != nil { return err } lastMtu = mtu } return nil } err = doIt() if err != nil { return nil, err } cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) { //fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix) if route.DestinationPrefix.PrefixLength == 0 { _ = doIt() } }) if err != nil { return nil, err } return cb, nil } func setDNSDomains(g windows.GUID, dnsDomains []string) { gs := g.String() log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs) p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE) if err != nil { log.Printf("setDNSDomains(%v): open: %v\n", p, err) return } defer key.Close() // Windows only supports a single per-interface DNS domain. dom := "" if len(dnsDomains) > 0 { dom = dnsDomains[0] } err = key.SetStringValue("Domain", dom) if err != nil { log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err) } } func setFirewall(ifcGUID *windows.GUID) (bool, error) { c := ole.Connection{} err := c.Initialize() if err != nil { return false, fmt.Errorf("c.Initialize: %v", err) } defer c.Uninitialize() m, err := winnet.NewNetworkListManager(&c) if err != nil { return false, fmt.Errorf("winnet.NewNetworkListManager: %v", err) } defer m.Release() cl, err := m.GetNetworkConnections() if err != nil { return false, fmt.Errorf("m.GetNetworkConnections: %v", err) } defer cl.Release() for _, nco := range cl { aid, err := nco.GetAdapterId() if err != nil { return false, fmt.Errorf("nco.GetAdapterId: %v", err) } if aid != ifcGUID.String() { log.Printf("skipping adapter id: %v\n", aid) continue } log.Printf("found! adapter id: %v\n", aid) n, err := nco.GetNetwork() if err != nil { return false, fmt.Errorf("GetNetwork: %v", err) } defer n.Release() cat, err := n.GetCategory() if err != nil { return false, fmt.Errorf("GetCategory: %v", err) } if cat == 0 { err = n.SetCategory(1) if err != nil { return false, fmt.Errorf("SetCategory: %v", err) } } else { log.Printf("setFirewall: already category %v\n", cat) } return true, nil } return false, nil } func configureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []wgcfg.IP, dnsDomains []string) error { const mtu = 0 guid := tun.GUID() log.Printf("wintun GUID is %v\n", guid) iface, err := winipcfg.InterfaceFromGUID(&guid) if err != nil { return err } go func() { // It takes a weirdly long time for Windows to notice the // new interface has come up. Poll periodically until it // does. for i := 0; i < 20; i++ { found, err := setFirewall(&guid) if err != nil { log.Printf("setFirewall: %v\n", err) // fall through anyway, this isn't fatal. } if found { break } time.Sleep(1 * time.Second) } }() setDNSDomains(guid, dnsDomains) routes := []winipcfg.RouteData{} var firstGateway4 *net.IP var firstGateway6 *net.IP addresses := make([]*net.IPNet, len(m.Addresses)) for i, addr := range m.Addresses { ipnet := addr.IPNet() addresses[i] = ipnet gateway := ipnet.IP if addr.IP.Is4() && firstGateway4 == nil { firstGateway4 = &gateway } else if addr.IP.Is6() && firstGateway6 == nil { firstGateway6 = &gateway } } foundDefault4 := false foundDefault6 := false for _, peer := range m.Peers { for _, allowedip := range peer.AllowedIPs { if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) { return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address") } ipn := allowedip.IPNet() var gateway net.IP if allowedip.IP.Is4() { gateway = *firstGateway4 } else if allowedip.IP.Is6() { gateway = *firstGateway6 } r := winipcfg.RouteData{ Destination: net.IPNet{ IP: ipn.IP.Mask(ipn.Mask), Mask: ipn.Mask, }, NextHop: gateway, Metric: 0, } if bytes.Compare(r.Destination.IP, gateway) == 0 { // no need to add a route for the interface's // own IP. The kernel does that for us. // If we try to replace it, we'll fail to // add the route unless NextHop is set, but // then the interface's IP won't be pingable. continue } if allowedip.IP.Is4() { if allowedip.Mask == 0 { foundDefault4 = true } r.NextHop = *firstGateway4 } else if allowedip.IP.Is6() { if allowedip.Mask == 0 { foundDefault6 = true } r.NextHop = *firstGateway6 } routes = append(routes, r) } } err = iface.SyncAddresses(addresses) if err != nil { return err } sort.Slice(routes, func(i, j int) bool { return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || // Narrower masks first bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 || // No nexthop before non-empty nexthop bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || // Lower metrics first routes[i].Metric < routes[j].Metric) }) deduplicatedRoutes := []*winipcfg.RouteData{} for i := 0; i < len(routes); i++ { // There's only one way to get to a given IP+Mask, so delete // all matches after the first. if i > 0 && bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { continue } deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) } log.Printf("routes: %v\n", routes) var errAcc error err = iface.SyncRoutes(deduplicatedRoutes) if err != nil && errAcc == nil { log.Printf("setroutes: %v\n", err) errAcc = err } var dnsIPs []net.IP for _, ip := range dns { dnsIPs = append(dnsIPs, ip.IP()) } err = iface.SetDNS(dnsIPs) if err != nil && errAcc == nil { log.Printf("setdns: %v\n", err) errAcc = err } ipif, err := iface.GetIpInterface(winipcfg.AF_INET) if err != nil { log.Printf("getipif: %v\n", err) return err } log.Printf("foundDefault4: %v\n", foundDefault4) if foundDefault4 { ipif.UseAutomaticMetric = false ipif.Metric = 0 } if mtu > 0 { ipif.NlMtu = uint32(mtu) tun.ForceMTU(int(ipif.NlMtu)) } err = ipif.Set() if err != nil && errAcc == nil { errAcc = err } ipif, err = iface.GetIpInterface(winipcfg.AF_INET6) if err != nil { return err } if err != nil && errAcc == nil { errAcc = err } if foundDefault6 { ipif.UseAutomaticMetric = false ipif.Metric = 0 } if mtu > 0 { ipif.NlMtu = uint32(mtu) } ipif.DadTransmits = 0 ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled err = ipif.Set() if err != nil && errAcc == nil { errAcc = err } return errAcc }