ipn,cmd/tailscale/cli: add Bring Your Own Mullvad Account (BYOMA) support

Updates #cleanup

Signed-off-by: Karthik Vinayan <karthikdoestech@gmail.com>
Karthik Vinayan 1 day ago
parent 5f34f14e14
commit 254a5f0213
No known key found for this signature in database
GPG Key ID: 4578443BAF04D59C

@ -1002,6 +1002,9 @@ func TestPrefFlagMapping(t *testing.T) {
case "AutoExitNode":
// Handled by tailscale {set,up} --exit-node=auto:any.
continue
case "CustomMullvadAccount":
// Handled by tailscale set --mullvad-account, not tailscale up.
continue
}
t.Errorf("unexpected new ipn.Pref field %q is not handled by up.go (see addPrefFlagMapping and checkForAccidentalSettingReverts)", prefName)
}

@ -69,6 +69,7 @@ type setArgsT struct {
netfilterMode string
relayServerPort string
relayServerStaticEndpoints string
mullvadAccount string
}
func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet {
@ -92,6 +93,7 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet {
setf.BoolVar(&setArgs.sync, "sync", false, hidden+"actively sync configuration from the control plane (set to false only for network failure testing)")
setf.StringVar(&setArgs.relayServerPort, "relay-server-port", "", "UDP port number (0 will pick a random unused port) for the relay server to bind to, on all interfaces, or empty string to disable relay server functionality")
setf.StringVar(&setArgs.relayServerStaticEndpoints, "relay-server-static-endpoints", "", "static IP:port endpoints to advertise as candidates for relay connections (comma-separated, e.g. \"[2001:db8::1]:40000,192.0.2.1:40000\") or empty string to not advertise any static endpoints")
setf.StringVar(&setArgs.mullvadAccount, "mullvad-account", "", "personal Mullvad account number (16 digits) for exit node access, or empty string to disable")
ffcomplete.Flag(setf, "exit-node", func(args []string) ([]string, ffcomplete.ShellCompDirective, error) {
st, err := localClient.Status(context.Background())
@ -163,8 +165,9 @@ func runSet(ctx context.Context, args []string) (retErr error) {
AppConnector: ipn.AppConnectorPrefs{
Advertise: setArgs.advertiseConnector,
},
PostureChecking: setArgs.reportPosture,
NoStatefulFiltering: opt.NewBool(!setArgs.statefulFiltering),
PostureChecking: setArgs.reportPosture,
NoStatefulFiltering: opt.NewBool(!setArgs.statefulFiltering),
CustomMullvadAccount: setArgs.mullvadAccount,
},
}

@ -36,4 +36,10 @@ const (
// ArgServerName provides a Warnable with comma delimited list of the hostname of the servers involved in the unhealthy state.
// If no nameservers were available to query, this will be an empty string.
ArgDNSServers Arg = "dns-servers"
// ArgDaysRemaining provides a Warnable with the number of days remaining before an account or subscription expires.
ArgDaysRemaining Arg = "days-remaining"
// ArgExpiryDate provides a Warnable with the expiry date of an account or subscription.
ArgExpiryDate Arg = "expiry-date"
)

@ -298,3 +298,46 @@ var warmingUpWarnable = condRegister(func() *Warnable {
Text: StaticMessage("Tailscale is starting. Please wait."),
}
})
// customMullvadExpiringWarnable is a Warnable that warns the user that their custom Mullvad account is about to expire.
var customMullvadExpiringWarnable = condRegister(func() *Warnable {
return &Warnable{
Code: tsconst.HealthWarnableCustomMullvadExpiring,
Title: "Mullvad account expiring soon",
Severity: SeverityMedium,
Text: func(args Args) string {
return fmt.Sprintf("Your custom Mullvad account expires in %s days. Renew to keep using Mullvad exit nodes.", args[ArgDaysRemaining])
},
}
})
// customMullvadExpiredWarnable is a Warnable that warns the user that their custom Mullvad account has expired.
var customMullvadExpiredWarnable = condRegister(func() *Warnable {
return &Warnable{
Code: tsconst.HealthWarnableCustomMullvadExpired,
Title: "Mullvad account expired",
Severity: SeverityHigh,
Text: func(args Args) string {
return fmt.Sprintf("Your custom Mullvad account expired on %s. Renew to restore Mullvad exit node functionality.", args[ArgExpiryDate])
},
}
})
// customMullvadAuthFailedWarnable is a Warnable that warns the user that authentication with their custom Mullvad account failed.
var customMullvadAuthFailedWarnable = condRegister(func() *Warnable {
return &Warnable{
Code: tsconst.HealthWarnableCustomMullvadAuthFailed,
Title: "Mullvad authentication failed",
Severity: SeverityMedium,
Text: func(args Args) string {
return fmt.Sprintf("Failed to authenticate with your custom Mullvad account: %v", args[ArgError])
},
}
})
// Exported warnables for custom Mullvad integration
var (
CustomMullvadExpiringWarnable = customMullvadExpiringWarnable
CustomMullvadExpiredWarnable = customMullvadExpiredWarnable
CustomMullvadAuthFailedWarnable = customMullvadAuthFailedWarnable
)

@ -104,6 +104,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct {
DriveShares []*drive.Share
RelayServerPort *uint16
RelayServerStaticEndpoints []netip.AddrPort
CustomMullvadAccount string
AllowSingleHosts marshalAsTrueInJSON
Persist *persist.Persist
}{})

@ -453,6 +453,12 @@ func (v PrefsView) RelayServerStaticEndpoints() views.Slice[netip.AddrPort] {
return views.SliceOf(v.ж.RelayServerStaticEndpoints)
}
// CustomMullvadAccount is the 16-digit Mullvad account number for
// "Bring Your Own Mullvad Account" (BYOMA) integration. When set,
// Tailscale will register its WireGuard key with Mullvad and expose
// Mullvad servers as exit nodes.
func (v PrefsView) CustomMullvadAccount() string { return v.ж.CustomMullvadAccount }
// AllowSingleHosts was a legacy field that was always true
// for the past 4.5 years. It controlled whether Tailscale
// peers got /32 or /128 routes for each other.
@ -506,6 +512,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct {
DriveShares []*drive.Share
RelayServerPort *uint16
RelayServerStaticEndpoints []netip.AddrPort
CustomMullvadAccount string
AllowSingleHosts marshalAsTrueInJSON
Persist *persist.Persist
}{})

@ -354,7 +354,7 @@ func TestDNSConfigForNetmap(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
verOS := cmp.Or(tt.os, "linux")
var log tstest.MemLogger
got := dnsConfigForNetmap(tt.nm, peersMap(tt.peers), tt.prefs.View(), tt.expired, log.Logf, verOS)
got := dnsConfigForNetmap(tt.nm, peersMap(tt.peers), tt.prefs.View(), tt.expired, log.Logf, verOS, false)
if !reflect.DeepEqual(got, tt.want) {
gotj, _ := json.MarshalIndent(got, "", "\t")
wantj, _ := json.MarshalIndent(tt.want, "", "\t")

@ -347,6 +347,10 @@ type LocalBackend struct {
// avoid unnecessary churn between multiple equally-good options.
lastSuggestedExitNode tailcfg.StableNodeID
// customMullvadState holds state for the "Bring Your Own Mullvad Account" feature.
// Protected by mu.
customMullvadState *customMullvadState
// allowedSuggestedExitNodes is a set of exit nodes permitted by the most recent
// [pkey.AllowedSuggestedExitNodes] value. The allowedSuggestedExitNodesMu
// mutex guards access to this set.
@ -1357,11 +1361,21 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) {
sb.AddUser(id, up)
}
exitNodeID := b.pm.CurrentPrefs().ExitNodeID()
for _, p := range cn.Peers() {
// Collect all peers, including custom Mullvad peers
peers := cn.Peers()
customMullvadPeers := b.getCustomMullvadPeersLocked()
allPeers := make([]tailcfg.NodeView, 0, len(peers)+len(customMullvadPeers))
allPeers = append(allPeers, peers...)
allPeers = append(allPeers, customMullvadPeers...)
for _, p := range allPeers {
tailscaleIPs := make([]netip.Addr, 0, p.Addresses().Len())
isWireGuardOnly := p.IsWireGuardOnly()
for i := range p.Addresses().Len() {
addr := p.Addresses().At(i)
if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) {
// Include Tailscale IPs, or any IP for WireGuard-only nodes (like Mullvad)
if addr.IsSingleIP() && (tsaddr.IsTailscaleIP(addr.Addr()) || isWireGuardOnly) {
tailscaleIPs = append(tailscaleIPs, addr.Addr())
}
}
@ -4505,6 +4519,33 @@ func (b *LocalBackend) onEditPrefsLocked(_ ipnauth.Actor, mp *ipn.MaskedPrefs, o
lastSuggestedExitNode: b.lastSuggestedExitNode,
}
e.record()
// Handle custom Mullvad account changes.
if mp.CustomMullvadAccountSet {
if oldPrefs.CustomMullvadAccount() != newPrefs.CustomMullvadAccount() {
account := newPrefs.CustomMullvadAccount()
// Use a background goroutine since API calls may block.
go func() {
b.logf("mullvad: starting configuration")
ctx, cancel := context.WithTimeout(b.ctx, 30*time.Second)
defer cancel()
if err := b.ConfigureCustomMullvad(ctx, account); err != nil {
b.logf("mullvad: failed to configure: %v", err)
} else {
b.logf("mullvad: configuration successful")
// Update engine's netmap to include custom peers for PeerForIP()
b.mu.Lock()
nm := b.currentNode().netMapWithPeers()
b.mu.Unlock()
if nm != nil {
b.e.SetNetworkMap(nm)
}
// Trigger WireGuard reconfiguration to include custom Mullvad peers
b.authReconfig()
}
}()
}
}
}
// startReconnectTimerLocked sets a timer to automatically set WantRunning to true
@ -5063,7 +5104,9 @@ func (b *LocalBackend) authReconfigLocked() {
cn := b.currentNode()
nm := cn.NetMap()
// Use netMapWithPeers to include custom peers (e.g., custom Mullvad exit nodes)
// in the WireGuard configuration.
nm := cn.netMapWithPeers()
if nm == nil {
b.logf("[v1] authReconfig: netmap not yet valid. Skipping.")
return
@ -5115,6 +5158,32 @@ func (b *LocalBackend) authReconfigLocked() {
return
}
// Inject custom Mullvad exit node peer into WireGuard config.
// The WGCfg function only knows about control-plane peers, so we need
// to add custom peers (like Mullvad) manually.
if exitNodeID := prefs.ExitNodeID(); exitNodeID != "" && strings.HasPrefix(string(exitNodeID), "custom-mullvad-") {
if peer, ok := cn.PeerByStableID(exitNodeID); ok {
wgPeer := wgcfg.Peer{
PublicKey: peer.Key(),
}
// Add AllowedIPs - for exit node this includes 0.0.0.0/0 and ::/0
for _, aip := range peer.AllowedIPs().All() {
wgPeer.AllowedIPs = append(wgPeer.AllowedIPs, aip)
}
// Set masquerade addresses - Mullvad expects traffic from our assigned IPs,
// not our Tailscale IP. The deviceInfo contains the IPs Mullvad assigned to us.
if b.customMullvadState != nil && b.customMullvadState.deviceInfo != nil {
if v4 := b.customMullvadState.deviceInfo.IPv4Address; v4.IsValid() {
wgPeer.V4MasqAddr = &v4
}
if v6 := b.customMullvadState.deviceInfo.IPv6Address; v6.IsValid() {
wgPeer.V6MasqAddr = &v6
}
}
cfg.Peers = append(cfg.Peers, wgPeer)
}
}
oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.NetMon.Get(), b.sys.ControlKnobs(), version.OS())
rcfg := b.routerConfigLocked(cfg, prefs, oneCGNATRoute)
@ -5476,7 +5545,21 @@ func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView
// likely to break some functionality, but if the user expressed a
// preference for routing remotely, we want to avoid leaking
// traffic at the expense of functionality.
if buildfeatures.HasUseExitNode && (prefs.ExitNodeID() != "" || prefs.ExitNodeIP().IsValid()) {
//
// For custom Mullvad exit nodes, we need to skip exit node routing
// until the Mullvad peers are configured. Otherwise, all traffic
// (including Mullvad API calls needed to configure peers) will try
// to route through a non-existent exit node and fail.
exitNodeReady := true
if exitNodeID := prefs.ExitNodeID(); exitNodeID != "" && strings.HasPrefix(string(exitNodeID), "custom-mullvad-") {
// Check if Mullvad is configured with peers.
// Use b.customMullvadState.peers (protected by b.mu) instead of
// b.currentNode().customPeers (which requires nodeBackend.mu).
if b.customMullvadState == nil || len(b.customMullvadState.peers) == 0 {
exitNodeReady = false
}
}
if buildfeatures.HasUseExitNode && exitNodeReady && (prefs.ExitNodeID() != "" || prefs.ExitNodeIP().IsValid()) {
var default4, default6 bool
for _, route := range rs.Routes {
switch route {
@ -5629,6 +5712,47 @@ func (b *LocalBackend) enterStateLocked(newState ipn.State) {
// necessary and add unit tests to cover those cases, or remove it.
if oldState != ipn.Running {
b.resetAuthURLLocked()
// Initialize or refresh custom Mullvad integration.
// Case 1: Account is configured in prefs but state not initialized (daemon restart)
// Case 2: State exists and may need refresh (key rotation during re-auth)
if account := prefs.CustomMullvadAccount(); account != "" && envknob.Bool("TS_ENABLE_CUSTOM_MULLVAD") {
if b.customMullvadState == nil {
// Daemon restart with persisted account - initialize from saved config
b.logf("mullvad: initializing from saved account on startup")
go func() {
ctx, cancel := context.WithTimeout(b.ctx, 30*time.Second)
defer cancel()
if err := b.ConfigureCustomMullvad(ctx, account); err != nil {
b.logf("mullvad: failed to initialize on startup: %v", err)
} else {
b.logf("mullvad: initialized from saved account")
// Update engine's netmap to include custom peers for PeerForIP()
b.mu.Lock()
nm := b.currentNode().netMapWithPeers()
b.mu.Unlock()
if nm != nil {
b.e.SetNetworkMap(nm)
}
b.authReconfig()
}
}()
} else if b.customMullvadState.client != nil {
// Re-register with Mullvad in case the node key changed
// during re-authentication. The Tailscale node key is used for
// Mullvad device registration.
go func() {
ctx, cancel := context.WithTimeout(b.ctx, 30*time.Second)
defer cancel()
if err := b.RefreshCustomMullvad(ctx); err != nil {
b.logf("mullvad: failed to re-register after auth: %v", err)
} else {
b.logf("mullvad: re-registered after auth")
b.authReconfig()
}
}()
}
}
}
// Start a captive portal detection loop if none has been

@ -2295,7 +2295,7 @@ func TestDNSConfigForNetmapForExitNodeConfigs(t *testing.T) {
}
prefs := &ipn.Prefs{ExitNodeID: tc.exitNode, CorpDNS: true}
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "")
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "", false)
if !resolversEqual(t, got.DefaultResolvers, tc.wantDefaultResolvers) {
t.Errorf("DefaultResolvers: got %#v, want %#v", got.DefaultResolvers, tc.wantDefaultResolvers)
}

@ -0,0 +1,738 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package mullvad provides a client for the Mullvad VPN API,
// enabling "Bring Your Own Mullvad Account" functionality.
package mullvad
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"sync"
"time"
"go4.org/mem"
"tailscale.com/envknob"
"tailscale.com/net/dnscache"
"tailscale.com/net/netx"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
// CustomMullvadEnabled reports whether custom Mullvad account support is enabled.
func CustomMullvadEnabled() bool {
return envknob.Bool("TS_ENABLE_CUSTOM_MULLVAD")
}
const (
// defaultAPIBase is the base URL for the Mullvad API.
defaultAPIBase = "https://api.mullvad.net"
// defaultWireGuardPort is the default WireGuard port for Mullvad servers.
defaultWireGuardPort = 51820
// serverCacheExpiry is how long to cache the server list.
serverCacheExpiry = 6 * time.Hour
// tokenRefreshMargin is how long before expiry to refresh the token.
tokenRefreshMargin = 5 * time.Minute
// maxAPIResponseSize is the maximum size of API response bodies we'll read.
maxAPIResponseSize = 1 << 20 // 1MB
)
var (
// ErrAccountExpired is returned when the Mullvad account has expired.
ErrAccountExpired = errors.New("mullvad account has expired")
// ErrInvalidAccount is returned when the account number is invalid.
ErrInvalidAccount = errors.New("invalid mullvad account number")
// ErrDeviceLimitReached is returned when the device limit is reached.
ErrDeviceLimitReached = errors.New("mullvad device limit reached")
// ErrNotEnabled is returned when custom Mullvad support is not enabled.
ErrNotEnabled = errors.New("custom mullvad support not enabled")
)
// Server represents a Mullvad WireGuard server.
type Server struct {
Hostname string // "us-nyc-wg-001"
IPv4 netip.Addr // Server's IPv4 address
IPv6 netip.Addr // Server's IPv6 address
PublicKey key.NodePublic // Server's WireGuard public key
Port uint16 // WireGuard port (usually 51820)
CountryCode string // ISO 3166-1 alpha-2 ("us")
CountryName string // "USA"
CityCode string // "nyc"
CityName string // "New York City"
Active bool // Whether the server is currently active
Owned bool // Whether Mullvad owns this server
}
// AccountStatus contains account information.
type AccountStatus struct {
Expiry time.Time
DaysLeft int
IsExpired bool
}
// DeviceInfo represents a registered device.
type DeviceInfo struct {
ID string
PublicKey key.NodePublic
IPv4Address netip.Addr
IPv6Address netip.Addr
Created time.Time
}
// API request types for JSON marshaling.
type (
tokenRequest struct {
AccountNumber string `json:"account_number"`
}
deviceRegisterRequest struct {
Pubkey string `json:"pubkey"`
HijackDNS bool `json:"hijack_dns"`
}
deviceRotateKeyRequest struct {
Pubkey string `json:"pubkey"`
}
)
// Client handles communication with the Mullvad API.
type Client struct {
httpClient *http.Client
logf logger.Logf
apiBase string
mu sync.Mutex
accountNum string
accessToken string
tokenExpiry time.Time
deviceID string
deviceInfo *DeviceInfo
// Cached data
servers []Server
serversFetched time.Time
}
// DialContextFunc is a function that dials a network connection.
// This allows the caller to provide a custom dialer that bypasses the Tailscale tunnel.
type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// NewClient creates a new Mullvad API client.
// dialFunc is used for HTTP connections (e.g., to bypass Tailscale tunnel).
// dnsResolver is used for DNS lookups (to bypass MagicDNS during bootstrap).
// If dialFunc is nil, the default system dialer is used.
// If dnsResolver is nil, DNS lookups use the system resolver directly.
func NewClient(accountNumber string, logf logger.Logf, dialFunc DialContextFunc, dnsResolver *dnscache.Resolver) (*Client, error) {
if !CustomMullvadEnabled() {
return nil, ErrNotEnabled
}
if !isValidAccountNumber(accountNumber) {
return nil, ErrInvalidAccount
}
httpClient := &http.Client{Timeout: 30 * time.Second}
if dialFunc != nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
if dnsResolver != nil {
// Wrap dialFunc with DNS caching that bypasses MagicDNS.
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
// which is unreachable before the Mullvad tunnel is established.
tr.DialContext = dnscache.Dialer(netx.DialFunc(dialFunc), dnsResolver)
} else {
tr.DialContext = dialFunc
}
httpClient.Transport = tr
}
return &Client{
httpClient: httpClient,
logf: logf,
apiBase: defaultAPIBase,
accountNum: accountNumber,
}, nil
}
// isValidAccountNumber checks if the account number is valid (16 digits).
func isValidAccountNumber(num string) bool {
if len(num) != 16 {
return false
}
for _, c := range num {
if c < '0' || c > '9' {
return false
}
}
return true
}
// tokenResponse represents the response from the token endpoint.
type tokenResponse struct {
AccessToken string `json:"access_token"`
Expiry time.Time `json:"expiry"`
}
// Authenticate obtains an access token from Mullvad.
func (c *Client) Authenticate(ctx context.Context) error {
// Fast path: check if we already have a valid token.
c.mu.Lock()
if c.accessToken != "" && time.Now().Add(tokenRefreshMargin).Before(c.tokenExpiry) {
c.mu.Unlock()
return nil
}
// Copy values needed for the request while holding the lock.
accountNum := c.accountNum
apiBase := c.apiBase
c.mu.Unlock()
// Do HTTP request without holding the lock to avoid blocking concurrent operations.
reqBodyJSON, err := json.Marshal(tokenRequest{AccountNumber: accountNum})
if err != nil {
return fmt.Errorf("marshaling auth request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", apiBase+"/auth/v1/token", bytes.NewReader(reqBodyJSON))
if err != nil {
return fmt.Errorf("creating auth request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("auth request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusNotFound {
return ErrInvalidAccount
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return fmt.Errorf("auth failed with status %d: %s", resp.StatusCode, string(body))
}
var tokenResp tokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return fmt.Errorf("decoding auth response: %w", err)
}
// Reacquire lock to update state.
c.mu.Lock()
c.accessToken = tokenResp.AccessToken
c.tokenExpiry = tokenResp.Expiry
c.mu.Unlock()
c.logf("mullvad: authenticated, token expires at %v", tokenResp.Expiry)
return nil
}
// accountResponse represents the response from the public account endpoint.
// Endpoint: GET https://api.mullvad.net/public/accounts/v1/{account}
// Example response: {"id":"1234567890123456","expiry":"2026-02-01T20:01:32+00:00"}
type accountResponse struct {
ID string `json:"id"`
Expiry time.Time `json:"expiry"`
}
// GetAccountStatus checks account expiry using the public API endpoint.
func (c *Client) GetAccountStatus(ctx context.Context) (*AccountStatus, error) {
// Use the public API endpoint that doesn't require auth
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/public/accounts/v1/"+c.accountNum, nil)
if err != nil {
return nil, fmt.Errorf("creating account status request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("account status request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, ErrInvalidAccount
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return nil, fmt.Errorf("account status failed with status %d: %s", resp.StatusCode, string(body))
}
var accResp accountResponse
if err := json.NewDecoder(resp.Body).Decode(&accResp); err != nil {
return nil, fmt.Errorf("decoding account response: %w", err)
}
now := time.Now()
daysLeft := int(accResp.Expiry.Sub(now).Hours() / 24)
if daysLeft < 0 {
daysLeft = 0
}
return &AccountStatus{
Expiry: accResp.Expiry,
DaysLeft: daysLeft,
IsExpired: now.After(accResp.Expiry),
}, nil
}
// deviceResponse represents a device from the API.
type deviceResponse struct {
ID string `json:"id"`
Pubkey string `json:"pubkey"`
IPv4Address string `json:"ipv4_address"`
IPv6Address string `json:"ipv6_address"`
Created string `json:"created"`
}
// RegisterDevice registers a WireGuard public key with Mullvad.
func (c *Client) RegisterDevice(ctx context.Context, pubkey key.NodePublic) (*DeviceInfo, error) {
if err := c.Authenticate(ctx); err != nil {
return nil, fmt.Errorf("authentication failed: %w", err)
}
c.mu.Lock()
token := c.accessToken
c.mu.Unlock()
// Encode the public key as base64
pubkeyBytes := pubkey.Raw32()
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
reqBodyJSON, err := json.Marshal(deviceRegisterRequest{Pubkey: pubkeyB64, HijackDNS: false})
if err != nil {
return nil, fmt.Errorf("marshaling register request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", c.apiBase+"/accounts/v1/devices", bytes.NewReader(reqBodyJSON))
if err != nil {
return nil, fmt.Errorf("creating register request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("register request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusConflict {
// Device already registered, try to find it
return c.findExistingDevice(ctx, pubkey)
}
if resp.StatusCode == http.StatusForbidden {
return nil, ErrDeviceLimitReached
}
if resp.StatusCode == http.StatusBadRequest {
// Check if it's a "pubkey already in use" error
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
if strings.Contains(string(body), "PUBKEY_IN_USE") || strings.Contains(string(body), "already in use") {
// Key already registered, try to find it
c.logf("mullvad: key already registered, looking up existing device")
return c.findExistingDevice(ctx, pubkey)
}
return nil, fmt.Errorf("register failed with status %d: %s", resp.StatusCode, string(body))
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return nil, fmt.Errorf("register failed with status %d: %s", resp.StatusCode, string(body))
}
var devResp deviceResponse
if err := json.NewDecoder(resp.Body).Decode(&devResp); err != nil {
return nil, fmt.Errorf("decoding register response: %w", err)
}
info, err := parseDeviceResponse(&devResp)
if err != nil {
return nil, err
}
c.mu.Lock()
c.deviceID = info.ID
c.deviceInfo = info
c.mu.Unlock()
c.logf("mullvad: registered device %s with IPv4 %v", info.ID, info.IPv4Address)
return info, nil
}
func parseDeviceResponse(resp *deviceResponse) (*DeviceInfo, error) {
// Parse public key
pubkeyBytes, err := base64.StdEncoding.DecodeString(resp.Pubkey)
if err != nil {
return nil, fmt.Errorf("decoding pubkey: %w", err)
}
if len(pubkeyBytes) != 32 {
return nil, fmt.Errorf("invalid pubkey length: %d", len(pubkeyBytes))
}
pubkey := key.NodePublicFromRaw32(mem.B(pubkeyBytes))
// Parse IPv4 address (strip the /32 suffix if present)
ipv4Str := strings.TrimSuffix(resp.IPv4Address, "/32")
ipv4, err := netip.ParseAddr(ipv4Str)
if err != nil {
return nil, fmt.Errorf("parsing IPv4: %w", err)
}
// Parse IPv6 address (strip the /128 suffix if present)
ipv6Str := strings.TrimSuffix(resp.IPv6Address, "/128")
ipv6, err := netip.ParseAddr(ipv6Str)
if err != nil {
return nil, fmt.Errorf("parsing IPv6: %w", err)
}
return &DeviceInfo{
ID: resp.ID,
PublicKey: pubkey,
IPv4Address: ipv4,
IPv6Address: ipv6,
}, nil
}
// findExistingDevice finds a device by its public key.
func (c *Client) findExistingDevice(ctx context.Context, pubkey key.NodePublic) (*DeviceInfo, error) {
devices, err := c.ListDevices(ctx)
if err != nil {
return nil, fmt.Errorf("listing devices: %w", err)
}
pubkeyBytes := pubkey.Raw32()
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
for _, dev := range devices {
devKeyBytes := dev.PublicKey.Raw32()
devKeyB64 := base64.StdEncoding.EncodeToString(devKeyBytes[:])
if devKeyB64 == pubkeyB64 {
c.mu.Lock()
c.deviceID = dev.ID
c.deviceInfo = dev
c.mu.Unlock()
return dev, nil
}
}
return nil, fmt.Errorf("device with key not found")
}
// ListDevices returns all devices registered to the account.
func (c *Client) ListDevices(ctx context.Context) ([]*DeviceInfo, error) {
if err := c.Authenticate(ctx); err != nil {
return nil, fmt.Errorf("authentication failed: %w", err)
}
c.mu.Lock()
token := c.accessToken
c.mu.Unlock()
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/accounts/v1/devices", nil)
if err != nil {
return nil, fmt.Errorf("creating list devices request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("list devices request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return nil, fmt.Errorf("list devices failed with status %d: %s", resp.StatusCode, string(body))
}
var devicesResp []deviceResponse
if err := json.NewDecoder(resp.Body).Decode(&devicesResp); err != nil {
return nil, fmt.Errorf("decoding devices response: %w", err)
}
devices := make([]*DeviceInfo, 0, len(devicesResp))
for _, d := range devicesResp {
info, err := parseDeviceResponse(&d)
if err != nil {
c.logf("mullvad: skipping device %s: %v", d.ID, err)
continue
}
devices = append(devices, info)
}
return devices, nil
}
// relayResponse represents a relay from the /public/relays/wireguard/v2/ API.
// Example: {"hostname":"us-nyc-wg-001","location":"us-nyc","active":true,"owned":false,
// "provider":"M247","ipv4_addr_in":"146.70.198.66","include_in_country":true,
// "weight":100,"public_key":"TUCaQc26/R6AGpkDUr8A8ytUs/e5+UVlIVujbuBwlzI=",
// "ipv6_addr_in":"2a0d:5600:9:c::f001"}
type relayResponse struct {
Hostname string `json:"hostname"`
Location string `json:"location"`
Active bool `json:"active"`
Owned bool `json:"owned"`
Provider string `json:"provider"`
IPv4AddrIn string `json:"ipv4_addr_in"`
IPv6AddrIn string `json:"ipv6_addr_in"`
PublicKey string `json:"public_key"`
IncludeInCountry bool `json:"include_in_country"`
Weight int `json:"weight"`
}
// locationInfo represents a location from the relay list.
type locationInfo struct {
Country string `json:"country"`
City string `json:"city"`
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
}
// relayListResponse represents the response from /public/relays/wireguard/v2/.
type relayListResponse struct {
Locations map[string]locationInfo `json:"locations"`
WireGuard struct {
Relays []relayResponse `json:"relays"`
} `json:"wireguard"`
}
// GetServers fetches the list of available Mullvad WireGuard servers.
func (c *Client) GetServers(ctx context.Context) ([]Server, error) {
c.mu.Lock()
// Check cache
if len(c.servers) > 0 && time.Since(c.serversFetched) < serverCacheExpiry {
servers := c.servers
c.mu.Unlock()
return servers, nil
}
c.mu.Unlock()
// Fetch the relay list from the v2 API
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/public/relays/wireguard/v2/", nil)
if err != nil {
return nil, fmt.Errorf("creating servers request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("servers request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return nil, fmt.Errorf("servers request failed with status %d: %s", resp.StatusCode, string(body))
}
var relayList relayListResponse
if err := json.NewDecoder(resp.Body).Decode(&relayList); err != nil {
return nil, fmt.Errorf("decoding servers response: %w", err)
}
servers := make([]Server, 0, len(relayList.WireGuard.Relays))
for _, r := range relayList.WireGuard.Relays {
// Parse IPv4
ipv4, err := netip.ParseAddr(r.IPv4AddrIn)
if err != nil {
c.logf("mullvad: skipping server %s: invalid IPv4: %v", r.Hostname, err)
continue
}
// Parse IPv6
var ipv6 netip.Addr
if r.IPv6AddrIn != "" {
ipv6, err = netip.ParseAddr(r.IPv6AddrIn)
if err != nil {
c.logf("mullvad: server %s has invalid IPv6: %v", r.Hostname, err)
// Don't skip, IPv6 is optional
}
}
// Parse public key
pubkeyBytes, err := base64.StdEncoding.DecodeString(r.PublicKey)
if err != nil {
c.logf("mullvad: skipping server %s: invalid pubkey: %v", r.Hostname, err)
continue
}
if len(pubkeyBytes) != 32 {
c.logf("mullvad: skipping server %s: wrong pubkey length", r.Hostname)
continue
}
pubkey := key.NodePublicFromRaw32(mem.B(pubkeyBytes))
// Get location info
loc := relayList.Locations[r.Location]
// Location code format is like "us-nyc", split to get country and city codes
countryCode := ""
cityCode := ""
if parts := strings.SplitN(r.Location, "-", 2); len(parts) == 2 {
countryCode = parts[0]
cityCode = parts[1]
}
servers = append(servers, Server{
Hostname: r.Hostname,
IPv4: ipv4,
IPv6: ipv6,
PublicKey: pubkey,
Port: defaultWireGuardPort,
CountryCode: countryCode,
CountryName: loc.Country,
CityCode: cityCode,
CityName: loc.City,
Active: r.Active,
Owned: r.Owned,
})
}
c.mu.Lock()
c.servers = servers
c.serversFetched = time.Now()
c.mu.Unlock()
c.logf("mullvad: fetched %d WireGuard servers", len(servers))
return servers, nil
}
// GetDeviceInfo returns the current device info.
func (c *Client) GetDeviceInfo() *DeviceInfo {
c.mu.Lock()
defer c.mu.Unlock()
return c.deviceInfo
}
// AccountNumber returns the account number.
func (c *Client) AccountNumber() string {
c.mu.Lock()
defer c.mu.Unlock()
return c.accountNum
}
// MaskedAccountNumber returns the account number with middle digits masked.
func (c *Client) MaskedAccountNumber() string {
c.mu.Lock()
num := c.accountNum
c.mu.Unlock()
if len(num) != 16 {
return "invalid"
}
return num[:4] + "********" + num[12:]
}
// RotateKey rotates the registered WireGuard key.
func (c *Client) RotateKey(ctx context.Context, newPubkey key.NodePublic) error {
if err := c.Authenticate(ctx); err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
c.mu.Lock()
token := c.accessToken
deviceID := c.deviceID
c.mu.Unlock()
if deviceID == "" {
return fmt.Errorf("no device registered")
}
pubkeyBytes := newPubkey.Raw32()
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
reqBodyJSON, err := json.Marshal(deviceRotateKeyRequest{Pubkey: pubkeyB64})
if err != nil {
return fmt.Errorf("marshaling rotate key request: %w", err)
}
endpoint := fmt.Sprintf("%s/accounts/v1/devices/%s/pubkey", c.apiBase, url.PathEscape(deviceID))
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(reqBodyJSON))
if err != nil {
return fmt.Errorf("creating rotate key request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("rotate key request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return fmt.Errorf("rotate key failed with status %d: %s", resp.StatusCode, string(body))
}
c.logf("mullvad: rotated key for device %s", deviceID)
return nil
}
// RemoveDevice removes the registered device.
func (c *Client) RemoveDevice(ctx context.Context) error {
if err := c.Authenticate(ctx); err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
c.mu.Lock()
token := c.accessToken
deviceID := c.deviceID
c.mu.Unlock()
if deviceID == "" {
return nil // Nothing to remove
}
endpoint := fmt.Sprintf("%s/accounts/v1/devices/%s", c.apiBase, url.PathEscape(deviceID))
req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil)
if err != nil {
return fmt.Errorf("creating remove device request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("remove device request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
return fmt.Errorf("remove device failed with status %d: %s", resp.StatusCode, string(body))
}
c.mu.Lock()
c.deviceID = ""
c.deviceInfo = nil
c.mu.Unlock()
c.logf("mullvad: removed device")
return nil
}
// VerifyConnection checks if traffic is routing through Mullvad.
func (c *Client) VerifyConnection(ctx context.Context) (bool, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "https://am.i.mullvad.net/connected", nil)
if err != nil {
return false, fmt.Errorf("creating verify request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return false, fmt.Errorf("verify request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
if err != nil {
return false, fmt.Errorf("reading verify response: %w", err)
}
return strings.Contains(string(body), "You are connected to Mullvad"), nil
}

@ -0,0 +1,390 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package mullvad
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
func TestIsValidAccountNumber(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{"valid", "1234567890123456", true},
{"too short", "12345678901234", false},
{"too long", "12345678901234567", false},
{"empty", "", false},
{"with letters", "123456789012345a", false},
{"with spaces", "1234567890123 56", false},
{"with dashes", "1234-5678-9012-3456", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isValidAccountNumber(tt.input)
if got != tt.want {
t.Errorf("isValidAccountNumber(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestNewClient(t *testing.T) {
// Note: This test requires TS_ENABLE_CUSTOM_MULLVAD=1 to pass
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
tests := []struct {
name string
account string
wantErr error
wantNil bool
}{
{"valid account", "1234567890123456", nil, false},
{"invalid account short", "123456", ErrInvalidAccount, true},
{"invalid account letters", "abcdef1234567890", ErrInvalidAccount, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.account, logger.Discard, nil, nil)
if err != tt.wantErr {
t.Errorf("NewClient() error = %v, want %v", err, tt.wantErr)
}
if (c == nil) != tt.wantNil {
t.Errorf("NewClient() returned nil = %v, want %v", c == nil, tt.wantNil)
}
})
}
}
func TestNewClientFeatureDisabled(t *testing.T) {
// Ensure the feature is disabled
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "")
_, err := NewClient("1234567890123456", logger.Discard, nil, nil)
if err != ErrNotEnabled {
t.Errorf("NewClient() error = %v, want %v", err, ErrNotEnabled)
}
}
func TestMaskedAccountNumber(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
c, err := NewClient("1234567890123456", logger.Discard, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
masked := c.MaskedAccountNumber()
expected := "1234********3456"
if masked != expected {
t.Errorf("MaskedAccountNumber() = %q, want %q", masked, expected)
}
}
// Mock server for testing API calls
type mockMullvadServer struct {
t *testing.T
accountToken string
expiry time.Time
devices []deviceResponse
servers []relayResponse
}
func (m *mockMullvadServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "POST" && r.URL.Path == "/auth/v1/token":
// Token endpoint
var req struct {
AccountNumber string `json:"account_number"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if req.AccountNumber != "1234567890123456" {
http.Error(w, "invalid account", http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(tokenResponse{
AccessToken: "test-token-" + req.AccountNumber,
Expiry: time.Now().Add(24 * time.Hour),
})
case r.Method == "GET" && r.URL.Path == "/public/accounts/v1/1234567890123456":
// Account status endpoint
json.NewEncoder(w).Encode(accountResponse{
Expiry: m.expiry,
})
case r.Method == "GET" && r.URL.Path == "/public/accounts/v1/0000000000000000":
// Invalid account
http.Error(w, "not found", http.StatusNotFound)
case r.Method == "GET" && r.URL.Path == "/accounts/v1/devices":
// List devices
if r.Header.Get("Authorization") == "" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(m.devices)
case r.Method == "POST" && r.URL.Path == "/accounts/v1/devices":
// Register device
if r.Header.Get("Authorization") == "" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var req struct {
Pubkey string `json:"pubkey"`
HijackDNS bool `json:"hijack_dns"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
resp := deviceResponse{
ID: "test-device-id",
Pubkey: req.Pubkey,
IPv4Address: "10.64.0.1/32",
IPv6Address: "fc00:bbbb:bbbb:bb01::1/128",
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(resp)
case r.URL.Path == "/public/relays/wireguard/v2/":
// Server list - return in the proper relayListResponse format
resp := relayListResponse{
Locations: map[string]locationInfo{
"us-nyc": {Country: "USA", City: "New York City"},
"de-fra": {Country: "Germany", City: "Frankfurt"},
},
}
resp.WireGuard.Relays = m.servers
json.NewEncoder(w).Encode(resp)
default:
http.Error(w, "not found", http.StatusNotFound)
}
}
func TestAuthenticate(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
mock := &mockMullvadServer{
t: t,
expiry: time.Now().Add(30 * 24 * time.Hour),
}
server := httptest.NewServer(mock)
defer server.Close()
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
c.apiBase = server.URL
ctx := context.Background()
if err := c.Authenticate(ctx); err != nil {
t.Errorf("Authenticate() error = %v", err)
}
}
func TestGetAccountStatus(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
expiry := time.Now().Add(30 * 24 * time.Hour)
mock := &mockMullvadServer{
t: t,
expiry: expiry,
}
server := httptest.NewServer(mock)
defer server.Close()
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
c.apiBase = server.URL
ctx := context.Background()
status, err := c.GetAccountStatus(ctx)
if err != nil {
t.Errorf("GetAccountStatus() error = %v", err)
}
if status.IsExpired {
t.Error("GetAccountStatus() reported account as expired")
}
if status.DaysLeft < 29 || status.DaysLeft > 31 {
t.Errorf("GetAccountStatus() DaysLeft = %d, want ~30", status.DaysLeft)
}
}
func TestRegisterDevice(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
mock := &mockMullvadServer{
t: t,
expiry: time.Now().Add(30 * 24 * time.Hour),
}
server := httptest.NewServer(mock)
defer server.Close()
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
c.apiBase = server.URL
ctx := context.Background()
// Generate a test key
nodeKey := key.NewNode()
pubKey := nodeKey.Public()
info, err := c.RegisterDevice(ctx, pubKey)
if err != nil {
t.Fatalf("RegisterDevice() error = %v", err)
}
if info.ID != "test-device-id" {
t.Errorf("RegisterDevice() ID = %s, want test-device-id", info.ID)
}
if !info.IPv4Address.IsValid() {
t.Error("RegisterDevice() returned invalid IPv4 address")
}
}
func TestGetServers(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
// Generate a valid base64-encoded public key
nodeKey := key.NewNode()
pubKeyBytes := nodeKey.Public().Raw32()
pubKeyB64 := base64.StdEncoding.EncodeToString(pubKeyBytes[:])
mock := &mockMullvadServer{
t: t,
expiry: time.Now().Add(30 * 24 * time.Hour),
servers: []relayResponse{
{
Hostname: "us-nyc-wg-001",
Location: "us-nyc",
IPv4AddrIn: "193.27.12.1",
IPv6AddrIn: "2a03:1b20:3:f011::a01f",
PublicKey: pubKeyB64,
Active: true,
Owned: true,
},
{
Hostname: "de-fra-wg-001",
Location: "de-fra",
IPv4AddrIn: "185.213.154.1",
IPv6AddrIn: "2a03:1b20:6:f011::a01f",
PublicKey: pubKeyB64,
Active: true,
Owned: true,
},
},
}
server := httptest.NewServer(mock)
defer server.Close()
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
c.apiBase = server.URL
ctx := context.Background()
servers, err := c.GetServers(ctx)
if err != nil {
t.Fatalf("GetServers() error = %v", err)
}
if len(servers) != 2 {
t.Errorf("GetServers() returned %d servers, want 2", len(servers))
}
// Verify first server details
if servers[0].Hostname != "us-nyc-wg-001" {
t.Errorf("GetServers()[0].Hostname = %s, want us-nyc-wg-001", servers[0].Hostname)
}
if servers[0].CountryCode != "us" {
t.Errorf("GetServers()[0].CountryCode = %s, want us", servers[0].CountryCode)
}
if !servers[0].Active {
t.Error("GetServers()[0].Active = false, want true")
}
}
func TestServerCache(t *testing.T) {
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
nodeKey := key.NewNode()
pubKeyBytes := nodeKey.Public().Raw32()
pubKeyB64 := base64.StdEncoding.EncodeToString(pubKeyBytes[:])
var callCount atomic.Int32
mock := &mockMullvadServer{
t: t,
expiry: time.Now().Add(30 * 24 * time.Hour),
servers: []relayResponse{
{
Hostname: "us-nyc-wg-001",
Location: "us-nyc",
IPv4AddrIn: "193.27.12.1",
PublicKey: pubKeyB64,
Active: true,
},
},
}
mux := http.NewServeMux()
mux.HandleFunc("/public/relays/wireguard/v2/", func(w http.ResponseWriter, r *http.Request) {
callCount.Add(1)
mock.ServeHTTP(w, r)
})
mux.HandleFunc("/", mock.ServeHTTP)
server := httptest.NewServer(mux)
defer server.Close()
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
c.apiBase = server.URL
ctx := context.Background()
// First call
_, err = c.GetServers(ctx)
if err != nil {
t.Fatalf("GetServers() error = %v", err)
}
// Second call should use cache
_, err = c.GetServers(ctx)
if err != nil {
t.Fatalf("GetServers() error = %v", err)
}
if callCount.Load() != 1 {
t.Errorf("GetServers() made %d API calls, want 1 (cache should be used)", callCount.Load())
}
}

@ -0,0 +1,656 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnlocal
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
"time"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/health"
"tailscale.com/ipn/ipnlocal/mullvad"
"tailscale.com/net/dnscache"
"tailscale.com/net/dns/publicdns"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
)
// customMullvadState holds the state for the custom Mullvad integration.
// All fields are protected by LocalBackend.mu.
type customMullvadState struct {
// client is the Mullvad API client
client *mullvad.Client
// peers are the injected Mullvad servers as tailcfg.Node entries
peers []*tailcfg.Node
// deviceInfo contains the registered device info from Mullvad
deviceInfo *mullvad.DeviceInfo
// lastRefresh is when the server list was last refreshed
lastRefresh time.Time
// accountStatus contains the last known account status
accountStatus *mullvad.AccountStatus
}
// configureCustomMullvadLocked sets up personal Mullvad account integration.
// Must be called with b.mu held.
func (b *LocalBackend) configureCustomMullvadLocked(ctx context.Context, accountNumber string) error {
if !mullvad.CustomMullvadEnabled() {
b.logf("mullvad: feature not enabled (TS_ENABLE_CUSTOM_MULLVAD not set)")
return mullvad.ErrNotEnabled
}
// If account number is empty, clear the configuration
if accountNumber == "" {
b.clearCustomMullvadLocked()
return nil
}
// Get or initialize custom mullvad state
state := b.getOrCreateCustomMullvadStateLocked()
// Check if we need to create a new client
if state.client == nil || state.client.AccountNumber() != accountNumber {
// Create a DNS resolver that bypasses MagicDNS for Mullvad API calls.
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
// which is unreachable before the Mullvad tunnel is established.
dnsResolver := b.createMullvadBootstrapDNSResolver()
// Use SystemDial to bypass Tailscale tunnel for Mullvad API calls.
// This is necessary because the exit node might be a Mullvad server,
// and we need to reach the Mullvad API directly.
client, err := mullvad.NewClient(accountNumber, b.logf, b.dialer.SystemDial, dnsResolver)
if err != nil {
return fmt.Errorf("creating Mullvad client: %w", err)
}
state.client = client
}
// Authenticate with Mullvad
if err := state.client.Authenticate(ctx); err != nil {
b.setCustomMullvadAuthFailedWarning(err)
return fmt.Errorf("authenticating with Mullvad: %w", err)
}
// Clear any previous auth failure warning
b.clearCustomMullvadAuthFailedWarning()
// Check account status
status, err := state.client.GetAccountStatus(ctx)
if err != nil {
b.setCustomMullvadAuthFailedWarning(err)
return fmt.Errorf("checking account status: %w", err)
}
state.accountStatus = status
if status.IsExpired {
b.setCustomMullvadExpiredWarning(status)
return mullvad.ErrAccountExpired
}
// Update health warnings based on account status
b.updateCustomMullvadHealthWarningsLocked(status)
// Ensure we have a dedicated WireGuard key for Mullvad
if err := b.ensureCustomMullvadKeyLocked(ctx, state); err != nil {
return fmt.Errorf("ensuring Mullvad key: %w", err)
}
// Fetch server list
servers, err := state.client.GetServers(ctx)
if err != nil {
return fmt.Errorf("fetching Mullvad servers: %w", err)
}
// Convert servers to tailcfg.Node entries
peers := make([]*tailcfg.Node, 0, len(servers))
for _, server := range servers {
if !server.Active {
continue
}
node := b.customMullvadServerToNode(server, state.deviceInfo)
peers = append(peers, node)
}
state.peers = peers
state.lastRefresh = time.Now()
// Inject peers into nodeBackend for proper routing and WireGuard config
b.injectCustomMullvadPeersLocked()
b.logf("mullvad: configured %d exit nodes from personal account", len(peers))
return nil
}
// getOrCreateCustomMullvadStateLocked returns the custom Mullvad state,
// creating it if necessary. Must be called with b.mu held.
func (b *LocalBackend) getOrCreateCustomMullvadStateLocked() *customMullvadState {
if b.customMullvadState == nil {
b.customMullvadState = &customMullvadState{}
}
return b.customMullvadState
}
// clearCustomMullvadLocked removes the custom Mullvad configuration.
// Must be called with b.mu held.
func (b *LocalBackend) clearCustomMullvadLocked() {
if b.customMullvadState == nil {
return
}
// Clear health warnings
b.clearCustomMullvadWarnings()
// Clear custom peers from nodeBackend
b.currentNode().SetCustomPeers(nil)
b.customMullvadState = nil
b.logf("mullvad: cleared custom Mullvad configuration")
}
// ensureCustomMullvadKeyLocked registers the Tailscale node's public key with Mullvad.
// We use the Tailscale node key (not a separate key) because wireguard-go only supports
// a single private key per device. The Tailscale key is what wireguard-go will use
// for all WireGuard connections, including to Mullvad servers.
// Must be called with b.mu held.
func (b *LocalBackend) ensureCustomMullvadKeyLocked(ctx context.Context, state *customMullvadState) error {
// Get the current Tailscale node key - this is what wireguard-go uses
priv := b.pm.CurrentPrefs().Persist().PrivateNodeKey()
if priv.IsZero() {
return fmt.Errorf("tailscale node key not available; login required")
}
pubKey := priv.Public()
// Check if we have a previously registered device ID
existingDeviceID, _ := b.loadMullvadDeviceIDLocked()
// Register/lookup device with Mullvad using Tailscale's public key
deviceInfo, err := state.client.RegisterDevice(ctx, pubKey)
if err != nil {
return fmt.Errorf("registering with Mullvad: %w", err)
}
state.deviceInfo = deviceInfo
// Save device ID for future lookups (key changes trigger re-registration)
if existingDeviceID != deviceInfo.ID {
if err := b.saveMullvadDeviceIDLocked(deviceInfo.ID); err != nil {
b.logf("mullvad: warning: failed to save device ID: %v", err)
}
}
b.logf("mullvad: registered Tailscale key with Mullvad (device: %s)", deviceInfo.ID)
return nil
}
// mullvadDeviceIDStateKey is the state key for storing the custom Mullvad device ID.
const mullvadDeviceIDStateKey = "custom-mullvad-device-id"
// saveMullvadDeviceIDLocked saves the Mullvad device ID to storage.
// Must be called with b.mu held.
func (b *LocalBackend) saveMullvadDeviceIDLocked(deviceID string) error {
if err := b.pm.WriteState(mullvadDeviceIDStateKey, []byte(deviceID)); err != nil {
return fmt.Errorf("writing device ID: %w", err)
}
return nil
}
// loadMullvadDeviceIDLocked loads the Mullvad device ID from storage.
// Must be called with b.mu held.
func (b *LocalBackend) loadMullvadDeviceIDLocked() (string, error) {
deviceID, err := b.pm.Store().ReadState(mullvadDeviceIDStateKey)
if err != nil {
return "", err
}
return string(deviceID), nil
}
// customMullvadServerToNode converts a Mullvad server to a tailcfg.Node.
func (b *LocalBackend) customMullvadServerToNode(server mullvad.Server, deviceInfo *mullvad.DeviceInfo) *tailcfg.Node {
// Generate a stable node ID from the hostname
h := sha256.Sum256([]byte("custom-mullvad-" + server.Hostname))
nodeID := tailcfg.NodeID(binary.BigEndian.Uint64(h[:8]))
endpoints := make([]netip.AddrPort, 0, 2)
if server.IPv4.IsValid() {
endpoints = append(endpoints, netip.AddrPortFrom(server.IPv4, server.Port))
}
if server.IPv6.IsValid() {
endpoints = append(endpoints, netip.AddrPortFrom(server.IPv6, server.Port))
}
// Use the server's public IPs as addresses for display purposes.
// This gives each server a unique IP in the exit-node list.
addresses := make([]netip.Prefix, 0, 2)
if server.IPv4.IsValid() {
addresses = append(addresses, netip.PrefixFrom(server.IPv4, 32))
}
if server.IPv6.IsValid() {
addresses = append(addresses, netip.PrefixFrom(server.IPv6, 128))
}
// AllowedIPs for exit node: 0.0.0.0/0 and ::/0
allowedIPs := []netip.Prefix{
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::/0"),
}
// Mullvad DNS servers
dnsResolvers := []*dnstype.Resolver{
{Addr: "10.64.0.1"}, // Mullvad internal DNS
}
return &tailcfg.Node{
ID: nodeID,
StableID: tailcfg.StableNodeID("custom-mullvad-" + server.Hostname),
Name: server.Hostname + ".mullvad.custom.",
Key: server.PublicKey,
IsWireGuardOnly: true,
Endpoints: endpoints,
AllowedIPs: allowedIPs,
Addresses: addresses,
Hostinfo: (&tailcfg.Hostinfo{
Location: &tailcfg.Location{
Country: server.CountryName,
CountryCode: strings.ToUpper(server.CountryCode),
City: server.CityName,
CityCode: strings.ToUpper(server.CityCode),
Priority: 50, // Lower than Tailscale-managed Mullvad
},
}).View(),
ExitNodeDNSResolvers: dnsResolvers,
CapMap: tailcfg.NodeCapMap{
tailcfg.NodeAttrSuggestExitNode: nil,
tailcfg.NodeAttrCustomMullvad: nil,
},
}
}
// getCustomMullvadPeers returns the custom Mullvad peers as NodeViews.
// Thread-safe.
func (b *LocalBackend) getCustomMullvadPeers() []tailcfg.NodeView {
b.mu.Lock()
defer b.mu.Unlock()
return b.getCustomMullvadPeersLocked()
}
// getCustomMullvadPeersLocked returns the custom Mullvad peers as NodeViews.
// Must be called with b.mu held.
func (b *LocalBackend) getCustomMullvadPeersLocked() []tailcfg.NodeView {
if b.customMullvadState == nil || len(b.customMullvadState.peers) == 0 {
return nil
}
peers := make([]tailcfg.NodeView, len(b.customMullvadState.peers))
for i, p := range b.customMullvadState.peers {
peers[i] = p.View()
}
return peers
}
// isCustomMullvadNode reports whether the node is a custom Mullvad node.
func isCustomMullvadNode(node tailcfg.NodeView) bool {
if !node.Valid() {
return false
}
return node.CapMap().Contains(tailcfg.NodeAttrCustomMullvad)
}
// isCustomMullvadNodeByStableID reports whether the StableNodeID is a custom Mullvad node.
func isCustomMullvadNodeByStableID(id tailcfg.StableNodeID) bool {
return strings.HasPrefix(string(id), "custom-mullvad-")
}
// Health warning helpers
func (b *LocalBackend) updateCustomMullvadHealthWarningsLocked(status *mullvad.AccountStatus) {
if status == nil {
return
}
// Clear existing warnings first
b.clearCustomMullvadWarnings()
if status.IsExpired {
b.setCustomMullvadExpiredWarning(status)
} else if status.DaysLeft <= 7 {
b.setCustomMullvadExpiringWarning(status)
}
}
func (b *LocalBackend) setCustomMullvadExpiringWarning(status *mullvad.AccountStatus) {
b.health.SetUnhealthy(health.CustomMullvadExpiringWarnable, health.Args{
health.ArgDaysRemaining: strconv.Itoa(status.DaysLeft),
})
}
func (b *LocalBackend) setCustomMullvadExpiredWarning(status *mullvad.AccountStatus) {
b.health.SetUnhealthy(health.CustomMullvadExpiredWarnable, health.Args{
health.ArgExpiryDate: status.Expiry.Format("2006-01-02"),
})
}
func (b *LocalBackend) setCustomMullvadAuthFailedWarning(err error) {
b.health.SetUnhealthy(health.CustomMullvadAuthFailedWarnable, health.Args{
health.ArgError: err.Error(),
})
}
func (b *LocalBackend) clearCustomMullvadAuthFailedWarning() {
b.health.SetHealthy(health.CustomMullvadAuthFailedWarnable)
}
func (b *LocalBackend) clearCustomMullvadWarnings() {
b.health.SetHealthy(health.CustomMullvadExpiringWarnable)
b.health.SetHealthy(health.CustomMullvadExpiredWarnable)
b.health.SetHealthy(health.CustomMullvadAuthFailedWarnable)
}
// refreshCustomMullvadLocked re-fetches the Mullvad server list.
// Must be called with b.mu held.
func (b *LocalBackend) refreshCustomMullvadLocked(ctx context.Context) error {
if b.customMullvadState == nil || b.customMullvadState.client == nil {
return nil
}
state := b.customMullvadState
// Check account status
status, err := state.client.GetAccountStatus(ctx)
if err != nil {
b.setCustomMullvadAuthFailedWarning(err)
return fmt.Errorf("checking account status: %w", err)
}
state.accountStatus = status
b.clearCustomMullvadAuthFailedWarning()
// Update health warnings
b.updateCustomMullvadHealthWarningsLocked(status)
if status.IsExpired {
return mullvad.ErrAccountExpired
}
// Fetch updated server list
servers, err := state.client.GetServers(ctx)
if err != nil {
return fmt.Errorf("fetching Mullvad servers: %w", err)
}
// Convert servers to tailcfg.Node entries
peers := make([]*tailcfg.Node, 0, len(servers))
for _, server := range servers {
if !server.Active {
continue
}
node := b.customMullvadServerToNode(server, state.deviceInfo)
peers = append(peers, node)
}
state.peers = peers
state.lastRefresh = time.Now()
// Inject updated peers into nodeBackend
b.injectCustomMullvadPeersLocked()
b.logf("mullvad: refreshed server list, %d exit nodes available", len(peers))
return nil
}
// CustomMullvadStatus returns information about the custom Mullvad configuration.
type CustomMullvadStatus struct {
Configured bool
AccountExpiry time.Time
DaysRemaining int
ServerCount int
DeviceIPv4 netip.Addr
DeviceIPv6 netip.Addr
LastRefresh time.Time
}
// GetCustomMullvadStatus returns the current custom Mullvad status.
// Thread-safe.
func (b *LocalBackend) GetCustomMullvadStatus() CustomMullvadStatus {
b.mu.Lock()
defer b.mu.Unlock()
if b.customMullvadState == nil || b.customMullvadState.client == nil {
return CustomMullvadStatus{}
}
state := b.customMullvadState
status := CustomMullvadStatus{
Configured: true,
ServerCount: len(state.peers),
LastRefresh: state.lastRefresh,
}
if state.accountStatus != nil {
status.AccountExpiry = state.accountStatus.Expiry
status.DaysRemaining = state.accountStatus.DaysLeft
}
if state.deviceInfo != nil {
status.DeviceIPv4 = state.deviceInfo.IPv4Address
status.DeviceIPv6 = state.deviceInfo.IPv6Address
}
return status
}
// injectCustomMullvadPeers adds custom Mullvad peers to the peer list.
// This is used during netmap processing.
func injectCustomMullvadPeers(peers []tailcfg.NodeView, customPeers []tailcfg.NodeView) []tailcfg.NodeView {
if len(customPeers) == 0 {
return peers
}
result := make([]tailcfg.NodeView, 0, len(peers)+len(customPeers))
result = append(result, peers...)
result = append(result, customPeers...)
return result
}
// filterCustomMullvadPeers returns only the custom Mullvad peers from a peer list.
func filterCustomMullvadPeers(peers views.Slice[tailcfg.NodeView]) []tailcfg.NodeView {
var result []tailcfg.NodeView
for i := range peers.Len() {
if isCustomMullvadNode(peers.At(i)) {
result = append(result, peers.At(i))
}
}
return result
}
// ConfigureCustomMullvad configures the custom Mullvad account.
// This is the public method called by the local API.
func (b *LocalBackend) ConfigureCustomMullvad(ctx context.Context, accountNumber string) error {
b.mu.Lock()
defer b.mu.Unlock()
return b.configureCustomMullvadLocked(ctx, accountNumber)
}
// RefreshCustomMullvad refreshes the custom Mullvad configuration.
// This is the public method called by the local API.
func (b *LocalBackend) RefreshCustomMullvad(ctx context.Context) error {
b.mu.Lock()
defer b.mu.Unlock()
return b.refreshCustomMullvadLocked(ctx)
}
// injectCustomMullvadPeersLocked updates the nodeBackend with custom Mullvad peers
// so they are included in peer lookups and WireGuard configuration.
// Must be called with b.mu held.
func (b *LocalBackend) injectCustomMullvadPeersLocked() {
if b.customMullvadState == nil || len(b.customMullvadState.peers) == 0 {
b.currentNode().SetCustomPeers(nil)
return
}
peerViews := make([]tailcfg.NodeView, len(b.customMullvadState.peers))
for i, p := range b.customMullvadState.peers {
peerViews[i] = p.View()
}
b.currentNode().SetCustomPeers(peerViews)
}
// createMullvadBootstrapDNSResolver creates a DNS resolver for Mullvad API calls
// that bypasses MagicDNS by using public DoH servers as a fallback.
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
// which is unreachable before the Mullvad tunnel is established.
//
// Pattern: Same as net/dns/resolver/forwarder.go getKnownDoHClientForProvider()
func (b *LocalBackend) createMullvadBootstrapDNSResolver() *dnscache.Resolver {
return &dnscache.Resolver{
Forward: dnscache.Get().Forward,
UseLastGood: true,
LookupIPFallback: b.resolveMullvadAPIViaDoH,
Logf: b.logf,
}
}
// resolveMullvadAPIViaDoH resolves hostnames using public DoH providers.
// Uses the existing publicdns package to get known DoH provider IPs,
// avoiding hardcoded DNS servers in this code.
// Uses SystemDial to ensure the DoH request bypasses the Tailscale tunnel.
func (b *LocalBackend) resolveMullvadAPIViaDoH(ctx context.Context, host string) ([]netip.Addr, error) {
// Try multiple DoH providers in order (same providers used by net/dns/resolver/forwarder.go)
dohProviders := []string{
"https://cloudflare-dns.com/dns-query",
"https://dns.google/dns-query",
"https://dns.quad9.net/dns-query",
}
var lastErr error
for _, dohBase := range dohProviders {
addrs, err := b.resolveViaDoHProvider(ctx, host, dohBase)
if err == nil && len(addrs) > 0 {
return addrs, nil
}
lastErr = err
b.logf("mullvad: DoH provider %s failed: %v", dohBase, err)
}
if lastErr != nil {
return nil, fmt.Errorf("all DoH providers failed, last error: %w", lastErr)
}
return nil, errors.New("no DoH providers available")
}
// resolveViaDoHProvider resolves a hostname using a specific DoH provider.
// Uses publicdns.DoHIPsOfBase() to get known IPs for the provider.
func (b *LocalBackend) resolveViaDoHProvider(ctx context.Context, host, dohBase string) ([]netip.Addr, error) {
// Get known IPs for this DoH provider from publicdns package.
// This avoids hardcoding IPs - we use Tailscale's existing known DoH provider list.
allIPs := publicdns.DoHIPsOfBase(dohBase)
if len(allIPs) == 0 {
return nil, fmt.Errorf("no known IPs for DoH provider %s", dohBase)
}
// Parse the DoH URL to get the hostname
dohURL, err := url.Parse(dohBase)
if err != nil {
return nil, err
}
// Create HTTP client that dials the DoH provider directly by IP.
// Pattern from net/dns/resolver/forwarder.go:438-442
dohResolver := &dnscache.Resolver{
SingleHost: dohURL.Hostname(),
SingleHostStaticResult: allIPs,
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = dnscache.Dialer(b.dialer.SystemDial, dohResolver)
hc := &http.Client{Transport: tr, Timeout: 10 * time.Second}
// Build DNS query for host (A and AAAA records)
addrs, err := b.doDoHQuery(ctx, hc, dohBase, host, dnsmessage.TypeA)
if err != nil {
b.logf("mullvad: DoH A query to %s for %s failed: %v", dohBase, host, err)
}
// Also try AAAA
addrs6, err6 := b.doDoHQuery(ctx, hc, dohBase, host, dnsmessage.TypeAAAA)
if err6 == nil {
addrs = append(addrs, addrs6...)
}
if len(addrs) == 0 {
return nil, fmt.Errorf("no addresses returned from %s for %s", dohBase, host)
}
return addrs, nil
}
// doDoHQuery performs a single DoH query for the given record type.
func (b *LocalBackend) doDoHQuery(ctx context.Context, hc *http.Client, dohBase, host string, qtype dnsmessage.Type) ([]netip.Addr, error) {
// Build DNS query packet
var msg dnsmessage.Message
msg.Header.ID = uint16(time.Now().UnixNano())
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{
{Name: dnsmessage.MustNewName(host + "."), Type: qtype, Class: dnsmessage.ClassINET},
}
packet, err := msg.Pack()
if err != nil {
return nil, err
}
// Send DoH request (RFC 8484)
req, err := http.NewRequestWithContext(ctx, "POST", dohBase, bytes.NewReader(packet))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
resp, err := hc.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("DoH returned status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Parse DNS response
var respMsg dnsmessage.Message
if err := respMsg.Unpack(body); err != nil {
return nil, err
}
var addrs []netip.Addr
for _, ans := range respMsg.Answers {
switch r := ans.Body.(type) {
case *dnsmessage.AResource:
addrs = append(addrs, netip.AddrFrom4(r.A))
case *dnsmessage.AAAAResource:
addrs = append(addrs, netip.AddrFrom16(r.AAAA))
}
}
return addrs, nil
}

@ -8,6 +8,7 @@ import (
"context"
"net/netip"
"slices"
"strings"
"sync"
"sync/atomic"
@ -15,6 +16,7 @@ import (
"tailscale.com/feature/buildfeatures"
"tailscale.com/ipn"
"tailscale.com/net/dns"
"tailscale.com/net/dns/publicdns"
"tailscale.com/net/tsaddr"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
@ -106,6 +108,11 @@ type nodeBackend struct {
// nodeByAddr maps nodes' own addresses (excluding subnet routes) to node IDs.
// It is mutated in place (with mu held) and must not escape the [nodeBackend].
nodeByAddr map[netip.Addr]tailcfg.NodeID
// customPeers holds externally-injected peers (e.g., custom Mullvad exit nodes)
// that should be included alongside control-plane peers for lookups and WireGuard config.
// It is mutated in place (with mu held) and must not escape the [nodeBackend].
customPeers map[tailcfg.NodeID]tailcfg.NodeView
}
func newNodeBackend(ctx context.Context, logf logger.Logf, bus *eventbus.Bus) *nodeBackend {
@ -222,8 +229,14 @@ func (nb *nodeBackend) NodeByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool)
return self, true
}
}
n, ok := nb.peers[id]
return n, ok
if n, ok := nb.peers[id]; ok {
return n, true
}
// Also check custom peers (e.g., custom Mullvad)
if n, ok := nb.customPeers[id]; ok {
return n, true
}
return tailcfg.NodeView{}, false
}
func (nb *nodeBackend) PeerByStableID(id tailcfg.StableNodeID) (_ tailcfg.NodeView, ok bool) {
@ -234,6 +247,12 @@ func (nb *nodeBackend) PeerByStableID(id tailcfg.StableNodeID) (_ tailcfg.NodeVi
return n, true
}
}
// Also check custom peers (e.g., custom Mullvad)
for _, n := range nb.customPeers {
if n.StableID() == id {
return n, true
}
}
return tailcfg.NodeView{}, false
}
@ -252,7 +271,21 @@ func (nb *nodeBackend) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileView, o
func (nb *nodeBackend) Peers() []tailcfg.NodeView {
nb.mu.Lock()
defer nb.mu.Unlock()
return slicesx.MapValues(nb.peers)
total := len(nb.peers) + len(nb.customPeers)
if total == 0 {
return nil
}
// Combine control-plane peers with custom peers
result := make([]tailcfg.NodeView, 0, total)
for _, p := range nb.peers {
result = append(result, p)
}
for _, p := range nb.customPeers {
result = append(result, p)
}
return result
}
func (nb *nodeBackend) PeersForTest() []tailcfg.NodeView {
@ -421,7 +454,17 @@ func (nb *nodeBackend) netMapWithPeers() *netmap.NetworkMap {
return nil
}
nm := ptr.To(*nb.netMap) // shallow clone
nm.Peers = slicesx.MapValues(nb.peers)
// Combine control-plane peers with custom peers (e.g., custom Mullvad)
totalPeers := len(nb.peers) + len(nb.customPeers)
nm.Peers = make([]tailcfg.NodeView, 0, totalPeers)
for _, p := range nb.peers {
nm.Peers = append(nm.Peers, p)
}
for _, p := range nb.customPeers {
nm.Peers = append(nm.Peers, p)
}
slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int {
return cmp.Compare(a.ID(), b.ID())
})
@ -434,20 +477,64 @@ func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) {
nb.netMap = nm
nb.updateNodeByAddrLocked()
nb.updatePeersLocked()
nv := magicsock.NodeViewsUpdate{}
if nm != nil {
nv.SelfNode = nm.SelfNode
nv.Peers = nm.Peers
nb.derpMapViewPub.Publish(nm.DERPMap.View())
} else {
nb.derpMapViewPub.Publish(tailcfg.DERPMapView{})
}
// Publish combined peers (control + custom) to magicsock
nb.publishNodeViewsUpdateLocked()
}
// SetCustomPeers sets externally-injected peers (e.g., custom Mullvad exit nodes)
// that should be included in peer lookups and WireGuard configuration alongside
// control-plane peers. Call with nil to clear custom peers.
func (nb *nodeBackend) SetCustomPeers(peers []tailcfg.NodeView) {
nb.mu.Lock()
defer nb.mu.Unlock()
if len(peers) == 0 {
nb.customPeers = nil
} else {
nb.customPeers = make(map[tailcfg.NodeID]tailcfg.NodeView, len(peers))
for _, p := range peers {
nb.customPeers[p.ID()] = p
}
}
// Rebuild address index to include custom peers
nb.updateNodeByAddrLocked()
// Notify magicsock about the updated peer set (combined control + custom peers).
// This is critical for WireGuard-only peers like custom Mullvad exit nodes,
// as magicsock needs to know about them to handle ParseEndpoint correctly.
nb.publishNodeViewsUpdateLocked()
}
// publishNodeViewsUpdateLocked publishes a NodeViewsUpdate with the combined set
// of control-plane peers and custom peers. Must be called with nb.mu held.
func (nb *nodeBackend) publishNodeViewsUpdateLocked() {
nv := magicsock.NodeViewsUpdate{}
if nb.netMap != nil {
nv.SelfNode = nb.netMap.SelfNode
// Combine control-plane peers with custom peers
totalPeers := len(nb.netMap.Peers) + len(nb.customPeers)
if totalPeers > 0 {
combinedPeers := make([]tailcfg.NodeView, 0, totalPeers)
combinedPeers = append(combinedPeers, nb.netMap.Peers...)
for _, p := range nb.customPeers {
combinedPeers = append(combinedPeers, p)
}
nv.Peers = combinedPeers
}
}
nb.nodeViewsPub.Publish(nv)
}
func (nb *nodeBackend) updateNodeByAddrLocked() {
nm := nb.netMap
if nm == nil {
if nm == nil && len(nb.customPeers) == 0 {
nb.nodeByAddr = nil
return
}
@ -467,10 +554,17 @@ func (nb *nodeBackend) updateNodeByAddrLocked() {
}
}
}
if nm.SelfNode.Valid() {
addNode(nm.SelfNode)
// Add self node and control-plane peers
if nm != nil {
if nm.SelfNode.Valid() {
addNode(nm.SelfNode)
}
for _, p := range nm.Peers {
addNode(p)
}
}
for _, p := range nm.Peers {
// Add custom peers (e.g., custom Mullvad exit nodes)
for _, p := range nb.customPeers {
addNode(p)
}
// Third pass, actually delete the unwanted items.
@ -566,7 +660,29 @@ func (nb *nodeBackend) setFilter(f *filter.Filter) {
func (nb *nodeBackend) dnsConfigForNetmap(prefs ipn.PrefsView, selfExpired bool, versionOS string) *dns.Config {
nb.mu.Lock()
defer nb.mu.Unlock()
return dnsConfigForNetmap(nb.netMap, nb.peers, prefs, selfExpired, nb.logf, versionOS)
// Check if we should skip exit node DNS for custom Mullvad.
// If exit node is custom-mullvad-* but customPeers is empty,
// the Mullvad init hasn't completed yet - use fallback DNS instead
// of unreachable Mullvad DNS.
exitNodeID := prefs.ExitNodeID()
skipExitNodeDNS := strings.HasPrefix(string(exitNodeID), "custom-mullvad-") && len(nb.customPeers) == 0
// Combine control-plane peers with custom peers for DNS config.
// This is necessary for WireGuard-only exit nodes (like custom Mullvad)
// so that wireguardExitNodeDNSResolvers can find the exit node.
allPeers := nb.peers
if len(nb.customPeers) > 0 {
allPeers = make(map[tailcfg.NodeID]tailcfg.NodeView, len(nb.peers)+len(nb.customPeers))
for id, p := range nb.peers {
allPeers[id] = p
}
for id, p := range nb.customPeers {
allPeers[id] = p
}
}
return dnsConfigForNetmap(nb.netMap, allPeers, prefs, selfExpired, nb.logf, versionOS, skipExitNodeDNS)
}
func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) {
@ -575,7 +691,20 @@ func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (doh
}
nb.mu.Lock()
defer nb.mu.Unlock()
return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID)
// Combine control-plane peers with custom peers.
allPeers := nb.peers
if len(nb.customPeers) > 0 {
allPeers = make(map[tailcfg.NodeID]tailcfg.NodeView, len(nb.peers)+len(nb.customPeers))
for id, p := range nb.peers {
allPeers[id] = p
}
for id, p := range nb.customPeers {
allPeers[id] = p
}
}
return exitNodeCanProxyDNS(nb.netMap, allPeers, exitNodeID)
}
// ready signals that [LocalBackend] has completed the switch to this [nodeBackend]
@ -673,7 +802,12 @@ func useWithExitNodeRoutes(routes map[string][]*dnstype.Resolver) map[string][]*
//
// The versionOS is a Tailscale-style version ("iOS", "macOS") and not
// a runtime.GOOS.
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config {
//
// If skipExitNodeDNS is true, the function will not use the exit node's
// DNS resolvers. This is used when the exit node is a custom Mullvad node
// that hasn't been initialized yet - using its DNS would fail because
// the Mullvad tunnel isn't established.
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string, skipExitNodeDNS bool) *dns.Config {
if nm == nil {
return nil
}
@ -813,7 +947,8 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.
// to run a DoH DNS proxy, then send all our DNS traffic through it,
// unless we find resolvers with UseWithExitNode set, in which case we use that.
if buildfeatures.HasUseExitNode {
if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok {
dohURL, canProxy := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID())
if canProxy {
filtered := useWithExitNodeResolvers(nm.DNS.Resolvers)
if len(filtered) > 0 {
addDefault(filtered)
@ -833,10 +968,23 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.
// node resolvers, use those as the default.
if len(nm.DNS.Resolvers) > 0 {
addDefault(nm.DNS.Resolvers)
} else if buildfeatures.HasUseExitNode {
} else if buildfeatures.HasUseExitNode && !skipExitNodeDNS {
if resolvers, ok := wireguardExitNodeDNSResolvers(nm, peers, prefs.ExitNodeID()); ok {
addDefault(resolvers)
}
} else if skipExitNodeDNS {
// Custom Mullvad exit node is not initialized yet. Use public DNS
// as a temporary fallback to avoid using stale/invalid DNS from
// the base system config (e.g., OrbStack's 0.250.250.200).
var fallback []*dnstype.Resolver
for _, ip := range publicdns.DoHIPsOfBase("https://dns.google/dns-query") {
fallback = append(fallback, &dnstype.Resolver{Addr: ip.String()})
}
if len(fallback) == 0 {
// Hardcode fallback if publicdns returns empty
fallback = []*dnstype.Resolver{{Addr: "8.8.8.8"}, {Addr: "8.8.4.4"}}
}
addDefault(fallback)
}
// Add split DNS routes, with no regard to exit node configuration.

@ -0,0 +1,167 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package localapi
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"tailscale.com/ipn/ipnlocal/mullvad"
"tailscale.com/util/httpm"
)
func init() {
// Register Mullvad endpoints unconditionally.
// The feature flag is checked at runtime in each handler.
Register("mullvad/status", (*Handler).serveMullvadStatus)
Register("mullvad/configure", (*Handler).serveMullvadConfigure)
Register("mullvad/refresh", (*Handler).serveMullvadRefresh)
}
// MullvadStatusResponse is the response for the mullvad/status endpoint.
type MullvadStatusResponse struct {
Configured bool `json:"configured"`
AccountExpiry time.Time `json:"accountExpiry,omitempty"`
DaysRemaining int `json:"daysRemaining,omitempty"`
ServerCount int `json:"serverCount,omitempty"`
DeviceIPv4 string `json:"deviceIPv4,omitempty"`
DeviceIPv6 string `json:"deviceIPv6,omitempty"`
LastRefresh time.Time `json:"lastRefresh,omitempty"`
}
// serveMullvadStatus returns the current custom Mullvad configuration status.
func (h *Handler) serveMullvadStatus(w http.ResponseWriter, r *http.Request) {
if !mullvad.CustomMullvadEnabled() {
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
return
}
if r.Method != httpm.GET {
http.Error(w, "only GET allowed", http.StatusMethodNotAllowed)
return
}
if !h.PermitRead {
http.Error(w, "read access denied", http.StatusForbidden)
return
}
status := h.b.GetCustomMullvadStatus()
resp := MullvadStatusResponse{
Configured: status.Configured,
AccountExpiry: status.AccountExpiry,
DaysRemaining: status.DaysRemaining,
ServerCount: status.ServerCount,
LastRefresh: status.LastRefresh,
}
if status.DeviceIPv4.IsValid() {
resp.DeviceIPv4 = status.DeviceIPv4.String()
}
if status.DeviceIPv6.IsValid() {
resp.DeviceIPv6 = status.DeviceIPv6.String()
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// MullvadConfigureRequest is the request for the mullvad/configure endpoint.
type MullvadConfigureRequest struct {
AccountNumber string `json:"accountNumber"`
}
// MullvadConfigureResponse is the response for the mullvad/configure endpoint.
type MullvadConfigureResponse struct {
Success bool `json:"success"`
Error string `json:"error,omitempty"`
AccountExpiry time.Time `json:"accountExpiry,omitempty"`
ServerCount int `json:"serverCount,omitempty"`
}
// serveMullvadConfigure configures the custom Mullvad account.
func (h *Handler) serveMullvadConfigure(w http.ResponseWriter, r *http.Request) {
if !mullvad.CustomMullvadEnabled() {
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
return
}
if r.Method != httpm.POST {
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
return
}
if !h.PermitWrite {
http.Error(w, "write access denied", http.StatusForbidden)
return
}
var req MullvadConfigureRequest
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<16)).Decode(&req); err != nil { // 64KB limit
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
// Use LocalBackend to configure Mullvad
// This is done via preferences to ensure persistence
if err := h.b.ConfigureCustomMullvad(r.Context(), req.AccountNumber); err != nil {
switch {
case errors.Is(err, mullvad.ErrInvalidAccount):
http.Error(w, err.Error(), http.StatusBadRequest)
case errors.Is(err, mullvad.ErrAccountExpired):
http.Error(w, err.Error(), http.StatusUnauthorized)
case errors.Is(err, mullvad.ErrNotEnabled):
http.Error(w, err.Error(), http.StatusNotImplemented)
default:
WriteErrorJSON(w, err)
}
return
}
status := h.b.GetCustomMullvadStatus()
resp := MullvadConfigureResponse{
Success: true,
AccountExpiry: status.AccountExpiry,
ServerCount: status.ServerCount,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// serveMullvadRefresh forces a refresh of the Mullvad server list.
func (h *Handler) serveMullvadRefresh(w http.ResponseWriter, r *http.Request) {
if !mullvad.CustomMullvadEnabled() {
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
return
}
if r.Method != httpm.POST {
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
return
}
if !h.PermitWrite {
http.Error(w, "write access denied", http.StatusForbidden)
return
}
if err := h.b.RefreshCustomMullvad(r.Context()); err != nil {
WriteErrorJSON(w, err)
return
}
status := h.b.GetCustomMullvadStatus()
resp := MullvadStatusResponse{
Configured: status.Configured,
AccountExpiry: status.AccountExpiry,
DaysRemaining: status.DaysRemaining,
ServerCount: status.ServerCount,
LastRefresh: status.LastRefresh,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// Ensure ipnlocal.LocalBackend has the necessary methods.
// These are defined in mullvad_integration.go.
// Note: We use context.Context from the standard library, not an interface type.
// This is just a compile-time check that the methods exist.

@ -291,6 +291,12 @@ type Prefs struct {
// non-nil.
RelayServerStaticEndpoints []netip.AddrPort `json:",omitempty"`
// CustomMullvadAccount is the 16-digit Mullvad account number for
// "Bring Your Own Mullvad Account" (BYOMA) integration. When set,
// Tailscale will register its WireGuard key with Mullvad and expose
// Mullvad servers as exit nodes.
CustomMullvadAccount string `json:",omitempty"`
// AllowSingleHosts was a legacy field that was always true
// for the past 4.5 years. It controlled whether Tailscale
// peers got /32 or /128 routes for each other.
@ -386,6 +392,7 @@ type MaskedPrefs struct {
DriveSharesSet bool `json:",omitempty"`
RelayServerPortSet bool `json:",omitempty"`
RelayServerStaticEndpointsSet bool `json:",omitzero"`
CustomMullvadAccountSet bool `json:",omitempty"`
}
// SetsInternal reports whether mp has any of the Internal*Set field bools set
@ -693,7 +700,8 @@ func (p *Prefs) Equals(p2 *Prefs) bool {
slices.EqualFunc(p.DriveShares, p2.DriveShares, drive.SharesEqual) &&
p.NetfilterKind == p2.NetfilterKind &&
compareUint16Ptrs(p.RelayServerPort, p2.RelayServerPort) &&
slices.Equal(p.RelayServerStaticEndpoints, p2.RelayServerStaticEndpoints)
slices.Equal(p.RelayServerStaticEndpoints, p2.RelayServerStaticEndpoints) &&
p.CustomMullvadAccount == p2.CustomMullvadAccount
}
func (au AutoUpdatePrefs) Pretty() string {

@ -70,6 +70,7 @@ func TestPrefsEqual(t *testing.T) {
"DriveShares",
"RelayServerPort",
"RelayServerStaticEndpoints",
"CustomMullvadAccount",
"AllowSingleHosts",
"Persist",
}
@ -390,6 +391,16 @@ func TestPrefsEqual(t *testing.T) {
&Prefs{RelayServerStaticEndpoints: aps("[2001:db8::1]:40000", "192.0.2.1:40000")},
false,
},
{
&Prefs{CustomMullvadAccount: "1234567890123456"},
&Prefs{CustomMullvadAccount: "1234567890123456"},
true,
},
{
&Prefs{CustomMullvadAccount: "1234567890123456"},
&Prefs{CustomMullvadAccount: "6543210987654321"},
false,
},
}
for i, tt := range tests {
got := tt.a.Equals(tt.b)

@ -2610,6 +2610,10 @@ const (
// NodeAttrSuggestExitNodeUI allows the currently suggested exit node to appear in the client GUI.
NodeAttrSuggestExitNodeUI NodeCapability = "suggest-exit-node-ui"
// NodeAttrCustomMullvad marks a node as a custom Mullvad exit node
// (part of the "Bring Your Own Mullvad Account" feature).
NodeAttrCustomMullvad NodeCapability = "custom-mullvad"
// NodeAttrUserDialUseRoutes makes UserDial use either the peer dialer or the system dialer,
// depending on the destination address and the configured routes. When present, it also makes
// the DNS forwarder use UserDial instead of SystemDial when dialing resolvers.

@ -23,4 +23,7 @@ const (
HealthWarnableTestWarnable = "test-warnable"
HealthWarnableApplyDiskConfig = "apply-disk-config"
HealthWarnableWarmingUp = "warming-up"
HealthWarnableCustomMullvadExpiring = "custom-mullvad-expiring"
HealthWarnableCustomMullvadExpired = "custom-mullvad-expired"
HealthWarnableCustomMullvadAuthFailed = "custom-mullvad-auth-failed"
)

Loading…
Cancel
Save