mirror of https://github.com/tailscale/tailscale/
ipn,cmd/tailscale/cli: add Bring Your Own Mullvad Account (BYOMA) support
Updates #cleanup Signed-off-by: Karthik Vinayan <karthikdoestech@gmail.com>
parent
5f34f14e14
commit
254a5f0213
@ -0,0 +1,738 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package mullvad provides a client for the Mullvad VPN API,
|
||||
// enabling "Bring Your Own Mullvad Account" functionality.
|
||||
package mullvad
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/netx"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// CustomMullvadEnabled reports whether custom Mullvad account support is enabled.
|
||||
func CustomMullvadEnabled() bool {
|
||||
return envknob.Bool("TS_ENABLE_CUSTOM_MULLVAD")
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultAPIBase is the base URL for the Mullvad API.
|
||||
defaultAPIBase = "https://api.mullvad.net"
|
||||
|
||||
// defaultWireGuardPort is the default WireGuard port for Mullvad servers.
|
||||
defaultWireGuardPort = 51820
|
||||
|
||||
// serverCacheExpiry is how long to cache the server list.
|
||||
serverCacheExpiry = 6 * time.Hour
|
||||
|
||||
// tokenRefreshMargin is how long before expiry to refresh the token.
|
||||
tokenRefreshMargin = 5 * time.Minute
|
||||
|
||||
// maxAPIResponseSize is the maximum size of API response bodies we'll read.
|
||||
maxAPIResponseSize = 1 << 20 // 1MB
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrAccountExpired is returned when the Mullvad account has expired.
|
||||
ErrAccountExpired = errors.New("mullvad account has expired")
|
||||
|
||||
// ErrInvalidAccount is returned when the account number is invalid.
|
||||
ErrInvalidAccount = errors.New("invalid mullvad account number")
|
||||
|
||||
// ErrDeviceLimitReached is returned when the device limit is reached.
|
||||
ErrDeviceLimitReached = errors.New("mullvad device limit reached")
|
||||
|
||||
// ErrNotEnabled is returned when custom Mullvad support is not enabled.
|
||||
ErrNotEnabled = errors.New("custom mullvad support not enabled")
|
||||
)
|
||||
|
||||
// Server represents a Mullvad WireGuard server.
|
||||
type Server struct {
|
||||
Hostname string // "us-nyc-wg-001"
|
||||
IPv4 netip.Addr // Server's IPv4 address
|
||||
IPv6 netip.Addr // Server's IPv6 address
|
||||
PublicKey key.NodePublic // Server's WireGuard public key
|
||||
Port uint16 // WireGuard port (usually 51820)
|
||||
CountryCode string // ISO 3166-1 alpha-2 ("us")
|
||||
CountryName string // "USA"
|
||||
CityCode string // "nyc"
|
||||
CityName string // "New York City"
|
||||
Active bool // Whether the server is currently active
|
||||
Owned bool // Whether Mullvad owns this server
|
||||
}
|
||||
|
||||
// AccountStatus contains account information.
|
||||
type AccountStatus struct {
|
||||
Expiry time.Time
|
||||
DaysLeft int
|
||||
IsExpired bool
|
||||
}
|
||||
|
||||
// DeviceInfo represents a registered device.
|
||||
type DeviceInfo struct {
|
||||
ID string
|
||||
PublicKey key.NodePublic
|
||||
IPv4Address netip.Addr
|
||||
IPv6Address netip.Addr
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// API request types for JSON marshaling.
|
||||
type (
|
||||
tokenRequest struct {
|
||||
AccountNumber string `json:"account_number"`
|
||||
}
|
||||
deviceRegisterRequest struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
HijackDNS bool `json:"hijack_dns"`
|
||||
}
|
||||
deviceRotateKeyRequest struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
}
|
||||
)
|
||||
|
||||
// Client handles communication with the Mullvad API.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
logf logger.Logf
|
||||
apiBase string
|
||||
|
||||
mu sync.Mutex
|
||||
accountNum string
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
deviceID string
|
||||
deviceInfo *DeviceInfo
|
||||
|
||||
// Cached data
|
||||
servers []Server
|
||||
serversFetched time.Time
|
||||
}
|
||||
|
||||
// DialContextFunc is a function that dials a network connection.
|
||||
// This allows the caller to provide a custom dialer that bypasses the Tailscale tunnel.
|
||||
type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// NewClient creates a new Mullvad API client.
|
||||
// dialFunc is used for HTTP connections (e.g., to bypass Tailscale tunnel).
|
||||
// dnsResolver is used for DNS lookups (to bypass MagicDNS during bootstrap).
|
||||
// If dialFunc is nil, the default system dialer is used.
|
||||
// If dnsResolver is nil, DNS lookups use the system resolver directly.
|
||||
func NewClient(accountNumber string, logf logger.Logf, dialFunc DialContextFunc, dnsResolver *dnscache.Resolver) (*Client, error) {
|
||||
if !CustomMullvadEnabled() {
|
||||
return nil, ErrNotEnabled
|
||||
}
|
||||
if !isValidAccountNumber(accountNumber) {
|
||||
return nil, ErrInvalidAccount
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
if dialFunc != nil {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
if dnsResolver != nil {
|
||||
// Wrap dialFunc with DNS caching that bypasses MagicDNS.
|
||||
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
|
||||
// which is unreachable before the Mullvad tunnel is established.
|
||||
tr.DialContext = dnscache.Dialer(netx.DialFunc(dialFunc), dnsResolver)
|
||||
} else {
|
||||
tr.DialContext = dialFunc
|
||||
}
|
||||
httpClient.Transport = tr
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
logf: logf,
|
||||
apiBase: defaultAPIBase,
|
||||
accountNum: accountNumber,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isValidAccountNumber checks if the account number is valid (16 digits).
|
||||
func isValidAccountNumber(num string) bool {
|
||||
if len(num) != 16 {
|
||||
return false
|
||||
}
|
||||
for _, c := range num {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tokenResponse represents the response from the token endpoint.
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
// Authenticate obtains an access token from Mullvad.
|
||||
func (c *Client) Authenticate(ctx context.Context) error {
|
||||
// Fast path: check if we already have a valid token.
|
||||
c.mu.Lock()
|
||||
if c.accessToken != "" && time.Now().Add(tokenRefreshMargin).Before(c.tokenExpiry) {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
// Copy values needed for the request while holding the lock.
|
||||
accountNum := c.accountNum
|
||||
apiBase := c.apiBase
|
||||
c.mu.Unlock()
|
||||
|
||||
// Do HTTP request without holding the lock to avoid blocking concurrent operations.
|
||||
reqBodyJSON, err := json.Marshal(tokenRequest{AccountNumber: accountNum})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling auth request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiBase+"/auth/v1/token", bytes.NewReader(reqBodyJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating auth request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusNotFound {
|
||||
return ErrInvalidAccount
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return fmt.Errorf("auth failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp tokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return fmt.Errorf("decoding auth response: %w", err)
|
||||
}
|
||||
|
||||
// Reacquire lock to update state.
|
||||
c.mu.Lock()
|
||||
c.accessToken = tokenResp.AccessToken
|
||||
c.tokenExpiry = tokenResp.Expiry
|
||||
c.mu.Unlock()
|
||||
|
||||
c.logf("mullvad: authenticated, token expires at %v", tokenResp.Expiry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// accountResponse represents the response from the public account endpoint.
|
||||
// Endpoint: GET https://api.mullvad.net/public/accounts/v1/{account}
|
||||
// Example response: {"id":"1234567890123456","expiry":"2026-02-01T20:01:32+00:00"}
|
||||
type accountResponse struct {
|
||||
ID string `json:"id"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
// GetAccountStatus checks account expiry using the public API endpoint.
|
||||
func (c *Client) GetAccountStatus(ctx context.Context) (*AccountStatus, error) {
|
||||
// Use the public API endpoint that doesn't require auth
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/public/accounts/v1/"+c.accountNum, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating account status request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("account status request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrInvalidAccount
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return nil, fmt.Errorf("account status failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var accResp accountResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&accResp); err != nil {
|
||||
return nil, fmt.Errorf("decoding account response: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
daysLeft := int(accResp.Expiry.Sub(now).Hours() / 24)
|
||||
if daysLeft < 0 {
|
||||
daysLeft = 0
|
||||
}
|
||||
|
||||
return &AccountStatus{
|
||||
Expiry: accResp.Expiry,
|
||||
DaysLeft: daysLeft,
|
||||
IsExpired: now.After(accResp.Expiry),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// deviceResponse represents a device from the API.
|
||||
type deviceResponse struct {
|
||||
ID string `json:"id"`
|
||||
Pubkey string `json:"pubkey"`
|
||||
IPv4Address string `json:"ipv4_address"`
|
||||
IPv6Address string `json:"ipv6_address"`
|
||||
Created string `json:"created"`
|
||||
}
|
||||
|
||||
// RegisterDevice registers a WireGuard public key with Mullvad.
|
||||
func (c *Client) RegisterDevice(ctx context.Context, pubkey key.NodePublic) (*DeviceInfo, error) {
|
||||
if err := c.Authenticate(ctx); err != nil {
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
token := c.accessToken
|
||||
c.mu.Unlock()
|
||||
|
||||
// Encode the public key as base64
|
||||
pubkeyBytes := pubkey.Raw32()
|
||||
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
|
||||
|
||||
reqBodyJSON, err := json.Marshal(deviceRegisterRequest{Pubkey: pubkeyB64, HijackDNS: false})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling register request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.apiBase+"/accounts/v1/devices", bytes.NewReader(reqBodyJSON))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating register request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
// Device already registered, try to find it
|
||||
return c.findExistingDevice(ctx, pubkey)
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, ErrDeviceLimitReached
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
// Check if it's a "pubkey already in use" error
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
if strings.Contains(string(body), "PUBKEY_IN_USE") || strings.Contains(string(body), "already in use") {
|
||||
// Key already registered, try to find it
|
||||
c.logf("mullvad: key already registered, looking up existing device")
|
||||
return c.findExistingDevice(ctx, pubkey)
|
||||
}
|
||||
return nil, fmt.Errorf("register failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return nil, fmt.Errorf("register failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var devResp deviceResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&devResp); err != nil {
|
||||
return nil, fmt.Errorf("decoding register response: %w", err)
|
||||
}
|
||||
|
||||
info, err := parseDeviceResponse(&devResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.deviceID = info.ID
|
||||
c.deviceInfo = info
|
||||
c.mu.Unlock()
|
||||
|
||||
c.logf("mullvad: registered device %s with IPv4 %v", info.ID, info.IPv4Address)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func parseDeviceResponse(resp *deviceResponse) (*DeviceInfo, error) {
|
||||
// Parse public key
|
||||
pubkeyBytes, err := base64.StdEncoding.DecodeString(resp.Pubkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding pubkey: %w", err)
|
||||
}
|
||||
if len(pubkeyBytes) != 32 {
|
||||
return nil, fmt.Errorf("invalid pubkey length: %d", len(pubkeyBytes))
|
||||
}
|
||||
pubkey := key.NodePublicFromRaw32(mem.B(pubkeyBytes))
|
||||
|
||||
// Parse IPv4 address (strip the /32 suffix if present)
|
||||
ipv4Str := strings.TrimSuffix(resp.IPv4Address, "/32")
|
||||
ipv4, err := netip.ParseAddr(ipv4Str)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing IPv4: %w", err)
|
||||
}
|
||||
|
||||
// Parse IPv6 address (strip the /128 suffix if present)
|
||||
ipv6Str := strings.TrimSuffix(resp.IPv6Address, "/128")
|
||||
ipv6, err := netip.ParseAddr(ipv6Str)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing IPv6: %w", err)
|
||||
}
|
||||
|
||||
return &DeviceInfo{
|
||||
ID: resp.ID,
|
||||
PublicKey: pubkey,
|
||||
IPv4Address: ipv4,
|
||||
IPv6Address: ipv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findExistingDevice finds a device by its public key.
|
||||
func (c *Client) findExistingDevice(ctx context.Context, pubkey key.NodePublic) (*DeviceInfo, error) {
|
||||
devices, err := c.ListDevices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing devices: %w", err)
|
||||
}
|
||||
|
||||
pubkeyBytes := pubkey.Raw32()
|
||||
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
|
||||
|
||||
for _, dev := range devices {
|
||||
devKeyBytes := dev.PublicKey.Raw32()
|
||||
devKeyB64 := base64.StdEncoding.EncodeToString(devKeyBytes[:])
|
||||
if devKeyB64 == pubkeyB64 {
|
||||
c.mu.Lock()
|
||||
c.deviceID = dev.ID
|
||||
c.deviceInfo = dev
|
||||
c.mu.Unlock()
|
||||
return dev, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("device with key not found")
|
||||
}
|
||||
|
||||
// ListDevices returns all devices registered to the account.
|
||||
func (c *Client) ListDevices(ctx context.Context) ([]*DeviceInfo, error) {
|
||||
if err := c.Authenticate(ctx); err != nil {
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
token := c.accessToken
|
||||
c.mu.Unlock()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/accounts/v1/devices", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating list devices request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list devices request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return nil, fmt.Errorf("list devices failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var devicesResp []deviceResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&devicesResp); err != nil {
|
||||
return nil, fmt.Errorf("decoding devices response: %w", err)
|
||||
}
|
||||
|
||||
devices := make([]*DeviceInfo, 0, len(devicesResp))
|
||||
for _, d := range devicesResp {
|
||||
info, err := parseDeviceResponse(&d)
|
||||
if err != nil {
|
||||
c.logf("mullvad: skipping device %s: %v", d.ID, err)
|
||||
continue
|
||||
}
|
||||
devices = append(devices, info)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// relayResponse represents a relay from the /public/relays/wireguard/v2/ API.
|
||||
// Example: {"hostname":"us-nyc-wg-001","location":"us-nyc","active":true,"owned":false,
|
||||
// "provider":"M247","ipv4_addr_in":"146.70.198.66","include_in_country":true,
|
||||
// "weight":100,"public_key":"TUCaQc26/R6AGpkDUr8A8ytUs/e5+UVlIVujbuBwlzI=",
|
||||
// "ipv6_addr_in":"2a0d:5600:9:c::f001"}
|
||||
type relayResponse struct {
|
||||
Hostname string `json:"hostname"`
|
||||
Location string `json:"location"`
|
||||
Active bool `json:"active"`
|
||||
Owned bool `json:"owned"`
|
||||
Provider string `json:"provider"`
|
||||
IPv4AddrIn string `json:"ipv4_addr_in"`
|
||||
IPv6AddrIn string `json:"ipv6_addr_in"`
|
||||
PublicKey string `json:"public_key"`
|
||||
IncludeInCountry bool `json:"include_in_country"`
|
||||
Weight int `json:"weight"`
|
||||
}
|
||||
|
||||
// locationInfo represents a location from the relay list.
|
||||
type locationInfo struct {
|
||||
Country string `json:"country"`
|
||||
City string `json:"city"`
|
||||
Latitude float64 `json:"latitude"`
|
||||
Longitude float64 `json:"longitude"`
|
||||
}
|
||||
|
||||
// relayListResponse represents the response from /public/relays/wireguard/v2/.
|
||||
type relayListResponse struct {
|
||||
Locations map[string]locationInfo `json:"locations"`
|
||||
WireGuard struct {
|
||||
Relays []relayResponse `json:"relays"`
|
||||
} `json:"wireguard"`
|
||||
}
|
||||
|
||||
// GetServers fetches the list of available Mullvad WireGuard servers.
|
||||
func (c *Client) GetServers(ctx context.Context) ([]Server, error) {
|
||||
c.mu.Lock()
|
||||
// Check cache
|
||||
if len(c.servers) > 0 && time.Since(c.serversFetched) < serverCacheExpiry {
|
||||
servers := c.servers
|
||||
c.mu.Unlock()
|
||||
return servers, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
// Fetch the relay list from the v2 API
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.apiBase+"/public/relays/wireguard/v2/", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating servers request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("servers request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return nil, fmt.Errorf("servers request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var relayList relayListResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&relayList); err != nil {
|
||||
return nil, fmt.Errorf("decoding servers response: %w", err)
|
||||
}
|
||||
|
||||
servers := make([]Server, 0, len(relayList.WireGuard.Relays))
|
||||
for _, r := range relayList.WireGuard.Relays {
|
||||
// Parse IPv4
|
||||
ipv4, err := netip.ParseAddr(r.IPv4AddrIn)
|
||||
if err != nil {
|
||||
c.logf("mullvad: skipping server %s: invalid IPv4: %v", r.Hostname, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse IPv6
|
||||
var ipv6 netip.Addr
|
||||
if r.IPv6AddrIn != "" {
|
||||
ipv6, err = netip.ParseAddr(r.IPv6AddrIn)
|
||||
if err != nil {
|
||||
c.logf("mullvad: server %s has invalid IPv6: %v", r.Hostname, err)
|
||||
// Don't skip, IPv6 is optional
|
||||
}
|
||||
}
|
||||
|
||||
// Parse public key
|
||||
pubkeyBytes, err := base64.StdEncoding.DecodeString(r.PublicKey)
|
||||
if err != nil {
|
||||
c.logf("mullvad: skipping server %s: invalid pubkey: %v", r.Hostname, err)
|
||||
continue
|
||||
}
|
||||
if len(pubkeyBytes) != 32 {
|
||||
c.logf("mullvad: skipping server %s: wrong pubkey length", r.Hostname)
|
||||
continue
|
||||
}
|
||||
pubkey := key.NodePublicFromRaw32(mem.B(pubkeyBytes))
|
||||
|
||||
// Get location info
|
||||
loc := relayList.Locations[r.Location]
|
||||
// Location code format is like "us-nyc", split to get country and city codes
|
||||
countryCode := ""
|
||||
cityCode := ""
|
||||
if parts := strings.SplitN(r.Location, "-", 2); len(parts) == 2 {
|
||||
countryCode = parts[0]
|
||||
cityCode = parts[1]
|
||||
}
|
||||
|
||||
servers = append(servers, Server{
|
||||
Hostname: r.Hostname,
|
||||
IPv4: ipv4,
|
||||
IPv6: ipv6,
|
||||
PublicKey: pubkey,
|
||||
Port: defaultWireGuardPort,
|
||||
CountryCode: countryCode,
|
||||
CountryName: loc.Country,
|
||||
CityCode: cityCode,
|
||||
CityName: loc.City,
|
||||
Active: r.Active,
|
||||
Owned: r.Owned,
|
||||
})
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.servers = servers
|
||||
c.serversFetched = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.logf("mullvad: fetched %d WireGuard servers", len(servers))
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// GetDeviceInfo returns the current device info.
|
||||
func (c *Client) GetDeviceInfo() *DeviceInfo {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.deviceInfo
|
||||
}
|
||||
|
||||
// AccountNumber returns the account number.
|
||||
func (c *Client) AccountNumber() string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.accountNum
|
||||
}
|
||||
|
||||
// MaskedAccountNumber returns the account number with middle digits masked.
|
||||
func (c *Client) MaskedAccountNumber() string {
|
||||
c.mu.Lock()
|
||||
num := c.accountNum
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(num) != 16 {
|
||||
return "invalid"
|
||||
}
|
||||
return num[:4] + "********" + num[12:]
|
||||
}
|
||||
|
||||
// RotateKey rotates the registered WireGuard key.
|
||||
func (c *Client) RotateKey(ctx context.Context, newPubkey key.NodePublic) error {
|
||||
if err := c.Authenticate(ctx); err != nil {
|
||||
return fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
token := c.accessToken
|
||||
deviceID := c.deviceID
|
||||
c.mu.Unlock()
|
||||
|
||||
if deviceID == "" {
|
||||
return fmt.Errorf("no device registered")
|
||||
}
|
||||
|
||||
pubkeyBytes := newPubkey.Raw32()
|
||||
pubkeyB64 := base64.StdEncoding.EncodeToString(pubkeyBytes[:])
|
||||
|
||||
reqBodyJSON, err := json.Marshal(deviceRotateKeyRequest{Pubkey: pubkeyB64})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling rotate key request: %w", err)
|
||||
}
|
||||
endpoint := fmt.Sprintf("%s/accounts/v1/devices/%s/pubkey", c.apiBase, url.PathEscape(deviceID))
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", endpoint, bytes.NewReader(reqBodyJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating rotate key request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rotate key request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return fmt.Errorf("rotate key failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
c.logf("mullvad: rotated key for device %s", deviceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDevice removes the registered device.
|
||||
func (c *Client) RemoveDevice(ctx context.Context) error {
|
||||
if err := c.Authenticate(ctx); err != nil {
|
||||
return fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
token := c.accessToken
|
||||
deviceID := c.deviceID
|
||||
c.mu.Unlock()
|
||||
|
||||
if deviceID == "" {
|
||||
return nil // Nothing to remove
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("%s/accounts/v1/devices/%s", c.apiBase, url.PathEscape(deviceID))
|
||||
req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating remove device request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove device request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
return fmt.Errorf("remove device failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.deviceID = ""
|
||||
c.deviceInfo = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
c.logf("mullvad: removed device")
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyConnection checks if traffic is routing through Mullvad.
|
||||
func (c *Client) VerifyConnection(ctx context.Context) (bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://am.i.mullvad.net/connected", nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("creating verify request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("verify request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxAPIResponseSize))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("reading verify response: %w", err)
|
||||
}
|
||||
|
||||
return strings.Contains(string(body), "You are connected to Mullvad"), nil
|
||||
}
|
||||
@ -0,0 +1,390 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package mullvad
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
func TestIsValidAccountNumber(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"valid", "1234567890123456", true},
|
||||
{"too short", "12345678901234", false},
|
||||
{"too long", "12345678901234567", false},
|
||||
{"empty", "", false},
|
||||
{"with letters", "123456789012345a", false},
|
||||
{"with spaces", "1234567890123 56", false},
|
||||
{"with dashes", "1234-5678-9012-3456", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isValidAccountNumber(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidAccountNumber(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
// Note: This test requires TS_ENABLE_CUSTOM_MULLVAD=1 to pass
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account string
|
||||
wantErr error
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid account", "1234567890123456", nil, false},
|
||||
{"invalid account short", "123456", ErrInvalidAccount, true},
|
||||
{"invalid account letters", "abcdef1234567890", ErrInvalidAccount, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(tt.account, logger.Discard, nil, nil)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("NewClient() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
if (c == nil) != tt.wantNil {
|
||||
t.Errorf("NewClient() returned nil = %v, want %v", c == nil, tt.wantNil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientFeatureDisabled(t *testing.T) {
|
||||
// Ensure the feature is disabled
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "")
|
||||
|
||||
_, err := NewClient("1234567890123456", logger.Discard, nil, nil)
|
||||
if err != ErrNotEnabled {
|
||||
t.Errorf("NewClient() error = %v, want %v", err, ErrNotEnabled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskedAccountNumber(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
c, err := NewClient("1234567890123456", logger.Discard, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
masked := c.MaskedAccountNumber()
|
||||
expected := "1234********3456"
|
||||
if masked != expected {
|
||||
t.Errorf("MaskedAccountNumber() = %q, want %q", masked, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Mock server for testing API calls
|
||||
type mockMullvadServer struct {
|
||||
t *testing.T
|
||||
accountToken string
|
||||
expiry time.Time
|
||||
devices []deviceResponse
|
||||
servers []relayResponse
|
||||
}
|
||||
|
||||
func (m *mockMullvadServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "POST" && r.URL.Path == "/auth/v1/token":
|
||||
// Token endpoint
|
||||
var req struct {
|
||||
AccountNumber string `json:"account_number"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.AccountNumber != "1234567890123456" {
|
||||
http.Error(w, "invalid account", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(tokenResponse{
|
||||
AccessToken: "test-token-" + req.AccountNumber,
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
case r.Method == "GET" && r.URL.Path == "/public/accounts/v1/1234567890123456":
|
||||
// Account status endpoint
|
||||
json.NewEncoder(w).Encode(accountResponse{
|
||||
Expiry: m.expiry,
|
||||
})
|
||||
|
||||
case r.Method == "GET" && r.URL.Path == "/public/accounts/v1/0000000000000000":
|
||||
// Invalid account
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
|
||||
case r.Method == "GET" && r.URL.Path == "/accounts/v1/devices":
|
||||
// List devices
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(m.devices)
|
||||
|
||||
case r.Method == "POST" && r.URL.Path == "/accounts/v1/devices":
|
||||
// Register device
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Pubkey string `json:"pubkey"`
|
||||
HijackDNS bool `json:"hijack_dns"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := deviceResponse{
|
||||
ID: "test-device-id",
|
||||
Pubkey: req.Pubkey,
|
||||
IPv4Address: "10.64.0.1/32",
|
||||
IPv6Address: "fc00:bbbb:bbbb:bb01::1/128",
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
case r.URL.Path == "/public/relays/wireguard/v2/":
|
||||
// Server list - return in the proper relayListResponse format
|
||||
resp := relayListResponse{
|
||||
Locations: map[string]locationInfo{
|
||||
"us-nyc": {Country: "USA", City: "New York City"},
|
||||
"de-fra": {Country: "Germany", City: "Frankfurt"},
|
||||
},
|
||||
}
|
||||
resp.WireGuard.Relays = m.servers
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
mock := &mockMullvadServer{
|
||||
t: t,
|
||||
expiry: time.Now().Add(30 * 24 * time.Hour),
|
||||
}
|
||||
server := httptest.NewServer(mock)
|
||||
defer server.Close()
|
||||
|
||||
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
c.apiBase = server.URL
|
||||
|
||||
ctx := context.Background()
|
||||
if err := c.Authenticate(ctx); err != nil {
|
||||
t.Errorf("Authenticate() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountStatus(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
expiry := time.Now().Add(30 * 24 * time.Hour)
|
||||
mock := &mockMullvadServer{
|
||||
t: t,
|
||||
expiry: expiry,
|
||||
}
|
||||
server := httptest.NewServer(mock)
|
||||
defer server.Close()
|
||||
|
||||
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
c.apiBase = server.URL
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := c.GetAccountStatus(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("GetAccountStatus() error = %v", err)
|
||||
}
|
||||
|
||||
if status.IsExpired {
|
||||
t.Error("GetAccountStatus() reported account as expired")
|
||||
}
|
||||
if status.DaysLeft < 29 || status.DaysLeft > 31 {
|
||||
t.Errorf("GetAccountStatus() DaysLeft = %d, want ~30", status.DaysLeft)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterDevice(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
mock := &mockMullvadServer{
|
||||
t: t,
|
||||
expiry: time.Now().Add(30 * 24 * time.Hour),
|
||||
}
|
||||
server := httptest.NewServer(mock)
|
||||
defer server.Close()
|
||||
|
||||
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
c.apiBase = server.URL
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Generate a test key
|
||||
nodeKey := key.NewNode()
|
||||
pubKey := nodeKey.Public()
|
||||
|
||||
info, err := c.RegisterDevice(ctx, pubKey)
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterDevice() error = %v", err)
|
||||
}
|
||||
|
||||
if info.ID != "test-device-id" {
|
||||
t.Errorf("RegisterDevice() ID = %s, want test-device-id", info.ID)
|
||||
}
|
||||
if !info.IPv4Address.IsValid() {
|
||||
t.Error("RegisterDevice() returned invalid IPv4 address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServers(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
// Generate a valid base64-encoded public key
|
||||
nodeKey := key.NewNode()
|
||||
pubKeyBytes := nodeKey.Public().Raw32()
|
||||
pubKeyB64 := base64.StdEncoding.EncodeToString(pubKeyBytes[:])
|
||||
|
||||
mock := &mockMullvadServer{
|
||||
t: t,
|
||||
expiry: time.Now().Add(30 * 24 * time.Hour),
|
||||
servers: []relayResponse{
|
||||
{
|
||||
Hostname: "us-nyc-wg-001",
|
||||
Location: "us-nyc",
|
||||
IPv4AddrIn: "193.27.12.1",
|
||||
IPv6AddrIn: "2a03:1b20:3:f011::a01f",
|
||||
PublicKey: pubKeyB64,
|
||||
Active: true,
|
||||
Owned: true,
|
||||
},
|
||||
{
|
||||
Hostname: "de-fra-wg-001",
|
||||
Location: "de-fra",
|
||||
IPv4AddrIn: "185.213.154.1",
|
||||
IPv6AddrIn: "2a03:1b20:6:f011::a01f",
|
||||
PublicKey: pubKeyB64,
|
||||
Active: true,
|
||||
Owned: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
server := httptest.NewServer(mock)
|
||||
defer server.Close()
|
||||
|
||||
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
c.apiBase = server.URL
|
||||
|
||||
ctx := context.Background()
|
||||
servers, err := c.GetServers(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServers() error = %v", err)
|
||||
}
|
||||
|
||||
if len(servers) != 2 {
|
||||
t.Errorf("GetServers() returned %d servers, want 2", len(servers))
|
||||
}
|
||||
|
||||
// Verify first server details
|
||||
if servers[0].Hostname != "us-nyc-wg-001" {
|
||||
t.Errorf("GetServers()[0].Hostname = %s, want us-nyc-wg-001", servers[0].Hostname)
|
||||
}
|
||||
if servers[0].CountryCode != "us" {
|
||||
t.Errorf("GetServers()[0].CountryCode = %s, want us", servers[0].CountryCode)
|
||||
}
|
||||
if !servers[0].Active {
|
||||
t.Error("GetServers()[0].Active = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerCache(t *testing.T) {
|
||||
t.Setenv("TS_ENABLE_CUSTOM_MULLVAD", "1")
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
pubKeyBytes := nodeKey.Public().Raw32()
|
||||
pubKeyB64 := base64.StdEncoding.EncodeToString(pubKeyBytes[:])
|
||||
|
||||
var callCount atomic.Int32
|
||||
mock := &mockMullvadServer{
|
||||
t: t,
|
||||
expiry: time.Now().Add(30 * 24 * time.Hour),
|
||||
servers: []relayResponse{
|
||||
{
|
||||
Hostname: "us-nyc-wg-001",
|
||||
Location: "us-nyc",
|
||||
IPv4AddrIn: "193.27.12.1",
|
||||
PublicKey: pubKeyB64,
|
||||
Active: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/public/relays/wireguard/v2/", func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount.Add(1)
|
||||
mock.ServeHTTP(w, r)
|
||||
})
|
||||
mux.HandleFunc("/", mock.ServeHTTP)
|
||||
|
||||
server := httptest.NewServer(mux)
|
||||
defer server.Close()
|
||||
|
||||
c, err := NewClient("1234567890123456", t.Logf, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
c.apiBase = server.URL
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First call
|
||||
_, err = c.GetServers(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServers() error = %v", err)
|
||||
}
|
||||
|
||||
// Second call should use cache
|
||||
_, err = c.GetServers(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServers() error = %v", err)
|
||||
}
|
||||
|
||||
if callCount.Load() != 1 {
|
||||
t.Errorf("GetServers() made %d API calls, want 1 (cache should be used)", callCount.Load())
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,656 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/ipn/ipnlocal/mullvad"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/dns/publicdns"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
// customMullvadState holds the state for the custom Mullvad integration.
|
||||
// All fields are protected by LocalBackend.mu.
|
||||
type customMullvadState struct {
|
||||
// client is the Mullvad API client
|
||||
client *mullvad.Client
|
||||
|
||||
// peers are the injected Mullvad servers as tailcfg.Node entries
|
||||
peers []*tailcfg.Node
|
||||
|
||||
// deviceInfo contains the registered device info from Mullvad
|
||||
deviceInfo *mullvad.DeviceInfo
|
||||
|
||||
// lastRefresh is when the server list was last refreshed
|
||||
lastRefresh time.Time
|
||||
|
||||
// accountStatus contains the last known account status
|
||||
accountStatus *mullvad.AccountStatus
|
||||
}
|
||||
|
||||
// configureCustomMullvadLocked sets up personal Mullvad account integration.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) configureCustomMullvadLocked(ctx context.Context, accountNumber string) error {
|
||||
if !mullvad.CustomMullvadEnabled() {
|
||||
b.logf("mullvad: feature not enabled (TS_ENABLE_CUSTOM_MULLVAD not set)")
|
||||
return mullvad.ErrNotEnabled
|
||||
}
|
||||
|
||||
// If account number is empty, clear the configuration
|
||||
if accountNumber == "" {
|
||||
b.clearCustomMullvadLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get or initialize custom mullvad state
|
||||
state := b.getOrCreateCustomMullvadStateLocked()
|
||||
|
||||
// Check if we need to create a new client
|
||||
if state.client == nil || state.client.AccountNumber() != accountNumber {
|
||||
// Create a DNS resolver that bypasses MagicDNS for Mullvad API calls.
|
||||
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
|
||||
// which is unreachable before the Mullvad tunnel is established.
|
||||
dnsResolver := b.createMullvadBootstrapDNSResolver()
|
||||
|
||||
// Use SystemDial to bypass Tailscale tunnel for Mullvad API calls.
|
||||
// This is necessary because the exit node might be a Mullvad server,
|
||||
// and we need to reach the Mullvad API directly.
|
||||
client, err := mullvad.NewClient(accountNumber, b.logf, b.dialer.SystemDial, dnsResolver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Mullvad client: %w", err)
|
||||
}
|
||||
state.client = client
|
||||
}
|
||||
|
||||
// Authenticate with Mullvad
|
||||
if err := state.client.Authenticate(ctx); err != nil {
|
||||
b.setCustomMullvadAuthFailedWarning(err)
|
||||
return fmt.Errorf("authenticating with Mullvad: %w", err)
|
||||
}
|
||||
|
||||
// Clear any previous auth failure warning
|
||||
b.clearCustomMullvadAuthFailedWarning()
|
||||
|
||||
// Check account status
|
||||
status, err := state.client.GetAccountStatus(ctx)
|
||||
if err != nil {
|
||||
b.setCustomMullvadAuthFailedWarning(err)
|
||||
return fmt.Errorf("checking account status: %w", err)
|
||||
}
|
||||
state.accountStatus = status
|
||||
|
||||
if status.IsExpired {
|
||||
b.setCustomMullvadExpiredWarning(status)
|
||||
return mullvad.ErrAccountExpired
|
||||
}
|
||||
|
||||
// Update health warnings based on account status
|
||||
b.updateCustomMullvadHealthWarningsLocked(status)
|
||||
|
||||
// Ensure we have a dedicated WireGuard key for Mullvad
|
||||
if err := b.ensureCustomMullvadKeyLocked(ctx, state); err != nil {
|
||||
return fmt.Errorf("ensuring Mullvad key: %w", err)
|
||||
}
|
||||
|
||||
// Fetch server list
|
||||
servers, err := state.client.GetServers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching Mullvad servers: %w", err)
|
||||
}
|
||||
|
||||
// Convert servers to tailcfg.Node entries
|
||||
peers := make([]*tailcfg.Node, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
if !server.Active {
|
||||
continue
|
||||
}
|
||||
node := b.customMullvadServerToNode(server, state.deviceInfo)
|
||||
peers = append(peers, node)
|
||||
}
|
||||
|
||||
state.peers = peers
|
||||
state.lastRefresh = time.Now()
|
||||
|
||||
// Inject peers into nodeBackend for proper routing and WireGuard config
|
||||
b.injectCustomMullvadPeersLocked()
|
||||
|
||||
b.logf("mullvad: configured %d exit nodes from personal account", len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreateCustomMullvadStateLocked returns the custom Mullvad state,
|
||||
// creating it if necessary. Must be called with b.mu held.
|
||||
func (b *LocalBackend) getOrCreateCustomMullvadStateLocked() *customMullvadState {
|
||||
if b.customMullvadState == nil {
|
||||
b.customMullvadState = &customMullvadState{}
|
||||
}
|
||||
return b.customMullvadState
|
||||
}
|
||||
|
||||
// clearCustomMullvadLocked removes the custom Mullvad configuration.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) clearCustomMullvadLocked() {
|
||||
if b.customMullvadState == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear health warnings
|
||||
b.clearCustomMullvadWarnings()
|
||||
|
||||
// Clear custom peers from nodeBackend
|
||||
b.currentNode().SetCustomPeers(nil)
|
||||
|
||||
b.customMullvadState = nil
|
||||
b.logf("mullvad: cleared custom Mullvad configuration")
|
||||
}
|
||||
|
||||
// ensureCustomMullvadKeyLocked registers the Tailscale node's public key with Mullvad.
|
||||
// We use the Tailscale node key (not a separate key) because wireguard-go only supports
|
||||
// a single private key per device. The Tailscale key is what wireguard-go will use
|
||||
// for all WireGuard connections, including to Mullvad servers.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) ensureCustomMullvadKeyLocked(ctx context.Context, state *customMullvadState) error {
|
||||
// Get the current Tailscale node key - this is what wireguard-go uses
|
||||
priv := b.pm.CurrentPrefs().Persist().PrivateNodeKey()
|
||||
if priv.IsZero() {
|
||||
return fmt.Errorf("tailscale node key not available; login required")
|
||||
}
|
||||
|
||||
pubKey := priv.Public()
|
||||
|
||||
// Check if we have a previously registered device ID
|
||||
existingDeviceID, _ := b.loadMullvadDeviceIDLocked()
|
||||
|
||||
// Register/lookup device with Mullvad using Tailscale's public key
|
||||
deviceInfo, err := state.client.RegisterDevice(ctx, pubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registering with Mullvad: %w", err)
|
||||
}
|
||||
state.deviceInfo = deviceInfo
|
||||
|
||||
// Save device ID for future lookups (key changes trigger re-registration)
|
||||
if existingDeviceID != deviceInfo.ID {
|
||||
if err := b.saveMullvadDeviceIDLocked(deviceInfo.ID); err != nil {
|
||||
b.logf("mullvad: warning: failed to save device ID: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.logf("mullvad: registered Tailscale key with Mullvad (device: %s)", deviceInfo.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// mullvadDeviceIDStateKey is the state key for storing the custom Mullvad device ID.
|
||||
const mullvadDeviceIDStateKey = "custom-mullvad-device-id"
|
||||
|
||||
// saveMullvadDeviceIDLocked saves the Mullvad device ID to storage.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) saveMullvadDeviceIDLocked(deviceID string) error {
|
||||
if err := b.pm.WriteState(mullvadDeviceIDStateKey, []byte(deviceID)); err != nil {
|
||||
return fmt.Errorf("writing device ID: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadMullvadDeviceIDLocked loads the Mullvad device ID from storage.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) loadMullvadDeviceIDLocked() (string, error) {
|
||||
deviceID, err := b.pm.Store().ReadState(mullvadDeviceIDStateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(deviceID), nil
|
||||
}
|
||||
|
||||
// customMullvadServerToNode converts a Mullvad server to a tailcfg.Node.
|
||||
func (b *LocalBackend) customMullvadServerToNode(server mullvad.Server, deviceInfo *mullvad.DeviceInfo) *tailcfg.Node {
|
||||
// Generate a stable node ID from the hostname
|
||||
h := sha256.Sum256([]byte("custom-mullvad-" + server.Hostname))
|
||||
nodeID := tailcfg.NodeID(binary.BigEndian.Uint64(h[:8]))
|
||||
|
||||
endpoints := make([]netip.AddrPort, 0, 2)
|
||||
if server.IPv4.IsValid() {
|
||||
endpoints = append(endpoints, netip.AddrPortFrom(server.IPv4, server.Port))
|
||||
}
|
||||
if server.IPv6.IsValid() {
|
||||
endpoints = append(endpoints, netip.AddrPortFrom(server.IPv6, server.Port))
|
||||
}
|
||||
|
||||
// Use the server's public IPs as addresses for display purposes.
|
||||
// This gives each server a unique IP in the exit-node list.
|
||||
addresses := make([]netip.Prefix, 0, 2)
|
||||
if server.IPv4.IsValid() {
|
||||
addresses = append(addresses, netip.PrefixFrom(server.IPv4, 32))
|
||||
}
|
||||
if server.IPv6.IsValid() {
|
||||
addresses = append(addresses, netip.PrefixFrom(server.IPv6, 128))
|
||||
}
|
||||
|
||||
// AllowedIPs for exit node: 0.0.0.0/0 and ::/0
|
||||
allowedIPs := []netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("::/0"),
|
||||
}
|
||||
|
||||
// Mullvad DNS servers
|
||||
dnsResolvers := []*dnstype.Resolver{
|
||||
{Addr: "10.64.0.1"}, // Mullvad internal DNS
|
||||
}
|
||||
|
||||
return &tailcfg.Node{
|
||||
ID: nodeID,
|
||||
StableID: tailcfg.StableNodeID("custom-mullvad-" + server.Hostname),
|
||||
Name: server.Hostname + ".mullvad.custom.",
|
||||
|
||||
Key: server.PublicKey,
|
||||
IsWireGuardOnly: true,
|
||||
|
||||
Endpoints: endpoints,
|
||||
AllowedIPs: allowedIPs,
|
||||
Addresses: addresses,
|
||||
|
||||
Hostinfo: (&tailcfg.Hostinfo{
|
||||
Location: &tailcfg.Location{
|
||||
Country: server.CountryName,
|
||||
CountryCode: strings.ToUpper(server.CountryCode),
|
||||
City: server.CityName,
|
||||
CityCode: strings.ToUpper(server.CityCode),
|
||||
Priority: 50, // Lower than Tailscale-managed Mullvad
|
||||
},
|
||||
}).View(),
|
||||
|
||||
ExitNodeDNSResolvers: dnsResolvers,
|
||||
|
||||
CapMap: tailcfg.NodeCapMap{
|
||||
tailcfg.NodeAttrSuggestExitNode: nil,
|
||||
tailcfg.NodeAttrCustomMullvad: nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// getCustomMullvadPeers returns the custom Mullvad peers as NodeViews.
|
||||
// Thread-safe.
|
||||
func (b *LocalBackend) getCustomMullvadPeers() []tailcfg.NodeView {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.getCustomMullvadPeersLocked()
|
||||
}
|
||||
|
||||
// getCustomMullvadPeersLocked returns the custom Mullvad peers as NodeViews.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) getCustomMullvadPeersLocked() []tailcfg.NodeView {
|
||||
if b.customMullvadState == nil || len(b.customMullvadState.peers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
peers := make([]tailcfg.NodeView, len(b.customMullvadState.peers))
|
||||
for i, p := range b.customMullvadState.peers {
|
||||
peers[i] = p.View()
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// isCustomMullvadNode reports whether the node is a custom Mullvad node.
|
||||
func isCustomMullvadNode(node tailcfg.NodeView) bool {
|
||||
if !node.Valid() {
|
||||
return false
|
||||
}
|
||||
return node.CapMap().Contains(tailcfg.NodeAttrCustomMullvad)
|
||||
}
|
||||
|
||||
// isCustomMullvadNodeByStableID reports whether the StableNodeID is a custom Mullvad node.
|
||||
func isCustomMullvadNodeByStableID(id tailcfg.StableNodeID) bool {
|
||||
return strings.HasPrefix(string(id), "custom-mullvad-")
|
||||
}
|
||||
|
||||
// Health warning helpers
|
||||
|
||||
func (b *LocalBackend) updateCustomMullvadHealthWarningsLocked(status *mullvad.AccountStatus) {
|
||||
if status == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear existing warnings first
|
||||
b.clearCustomMullvadWarnings()
|
||||
|
||||
if status.IsExpired {
|
||||
b.setCustomMullvadExpiredWarning(status)
|
||||
} else if status.DaysLeft <= 7 {
|
||||
b.setCustomMullvadExpiringWarning(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LocalBackend) setCustomMullvadExpiringWarning(status *mullvad.AccountStatus) {
|
||||
b.health.SetUnhealthy(health.CustomMullvadExpiringWarnable, health.Args{
|
||||
health.ArgDaysRemaining: strconv.Itoa(status.DaysLeft),
|
||||
})
|
||||
}
|
||||
|
||||
func (b *LocalBackend) setCustomMullvadExpiredWarning(status *mullvad.AccountStatus) {
|
||||
b.health.SetUnhealthy(health.CustomMullvadExpiredWarnable, health.Args{
|
||||
health.ArgExpiryDate: status.Expiry.Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
func (b *LocalBackend) setCustomMullvadAuthFailedWarning(err error) {
|
||||
b.health.SetUnhealthy(health.CustomMullvadAuthFailedWarnable, health.Args{
|
||||
health.ArgError: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func (b *LocalBackend) clearCustomMullvadAuthFailedWarning() {
|
||||
b.health.SetHealthy(health.CustomMullvadAuthFailedWarnable)
|
||||
}
|
||||
|
||||
func (b *LocalBackend) clearCustomMullvadWarnings() {
|
||||
b.health.SetHealthy(health.CustomMullvadExpiringWarnable)
|
||||
b.health.SetHealthy(health.CustomMullvadExpiredWarnable)
|
||||
b.health.SetHealthy(health.CustomMullvadAuthFailedWarnable)
|
||||
}
|
||||
|
||||
// refreshCustomMullvadLocked re-fetches the Mullvad server list.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) refreshCustomMullvadLocked(ctx context.Context) error {
|
||||
if b.customMullvadState == nil || b.customMullvadState.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := b.customMullvadState
|
||||
|
||||
// Check account status
|
||||
status, err := state.client.GetAccountStatus(ctx)
|
||||
if err != nil {
|
||||
b.setCustomMullvadAuthFailedWarning(err)
|
||||
return fmt.Errorf("checking account status: %w", err)
|
||||
}
|
||||
state.accountStatus = status
|
||||
b.clearCustomMullvadAuthFailedWarning()
|
||||
|
||||
// Update health warnings
|
||||
b.updateCustomMullvadHealthWarningsLocked(status)
|
||||
|
||||
if status.IsExpired {
|
||||
return mullvad.ErrAccountExpired
|
||||
}
|
||||
|
||||
// Fetch updated server list
|
||||
servers, err := state.client.GetServers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching Mullvad servers: %w", err)
|
||||
}
|
||||
|
||||
// Convert servers to tailcfg.Node entries
|
||||
peers := make([]*tailcfg.Node, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
if !server.Active {
|
||||
continue
|
||||
}
|
||||
node := b.customMullvadServerToNode(server, state.deviceInfo)
|
||||
peers = append(peers, node)
|
||||
}
|
||||
|
||||
state.peers = peers
|
||||
state.lastRefresh = time.Now()
|
||||
|
||||
// Inject updated peers into nodeBackend
|
||||
b.injectCustomMullvadPeersLocked()
|
||||
|
||||
b.logf("mullvad: refreshed server list, %d exit nodes available", len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CustomMullvadStatus returns information about the custom Mullvad configuration.
|
||||
type CustomMullvadStatus struct {
|
||||
Configured bool
|
||||
AccountExpiry time.Time
|
||||
DaysRemaining int
|
||||
ServerCount int
|
||||
DeviceIPv4 netip.Addr
|
||||
DeviceIPv6 netip.Addr
|
||||
LastRefresh time.Time
|
||||
}
|
||||
|
||||
// GetCustomMullvadStatus returns the current custom Mullvad status.
|
||||
// Thread-safe.
|
||||
func (b *LocalBackend) GetCustomMullvadStatus() CustomMullvadStatus {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.customMullvadState == nil || b.customMullvadState.client == nil {
|
||||
return CustomMullvadStatus{}
|
||||
}
|
||||
|
||||
state := b.customMullvadState
|
||||
status := CustomMullvadStatus{
|
||||
Configured: true,
|
||||
ServerCount: len(state.peers),
|
||||
LastRefresh: state.lastRefresh,
|
||||
}
|
||||
|
||||
if state.accountStatus != nil {
|
||||
status.AccountExpiry = state.accountStatus.Expiry
|
||||
status.DaysRemaining = state.accountStatus.DaysLeft
|
||||
}
|
||||
|
||||
if state.deviceInfo != nil {
|
||||
status.DeviceIPv4 = state.deviceInfo.IPv4Address
|
||||
status.DeviceIPv6 = state.deviceInfo.IPv6Address
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// injectCustomMullvadPeers adds custom Mullvad peers to the peer list.
|
||||
// This is used during netmap processing.
|
||||
func injectCustomMullvadPeers(peers []tailcfg.NodeView, customPeers []tailcfg.NodeView) []tailcfg.NodeView {
|
||||
if len(customPeers) == 0 {
|
||||
return peers
|
||||
}
|
||||
result := make([]tailcfg.NodeView, 0, len(peers)+len(customPeers))
|
||||
result = append(result, peers...)
|
||||
result = append(result, customPeers...)
|
||||
return result
|
||||
}
|
||||
|
||||
// filterCustomMullvadPeers returns only the custom Mullvad peers from a peer list.
|
||||
func filterCustomMullvadPeers(peers views.Slice[tailcfg.NodeView]) []tailcfg.NodeView {
|
||||
var result []tailcfg.NodeView
|
||||
for i := range peers.Len() {
|
||||
if isCustomMullvadNode(peers.At(i)) {
|
||||
result = append(result, peers.At(i))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConfigureCustomMullvad configures the custom Mullvad account.
|
||||
// This is the public method called by the local API.
|
||||
func (b *LocalBackend) ConfigureCustomMullvad(ctx context.Context, accountNumber string) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.configureCustomMullvadLocked(ctx, accountNumber)
|
||||
}
|
||||
|
||||
// RefreshCustomMullvad refreshes the custom Mullvad configuration.
|
||||
// This is the public method called by the local API.
|
||||
func (b *LocalBackend) RefreshCustomMullvad(ctx context.Context) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.refreshCustomMullvadLocked(ctx)
|
||||
}
|
||||
|
||||
// injectCustomMullvadPeersLocked updates the nodeBackend with custom Mullvad peers
|
||||
// so they are included in peer lookups and WireGuard configuration.
|
||||
// Must be called with b.mu held.
|
||||
func (b *LocalBackend) injectCustomMullvadPeersLocked() {
|
||||
if b.customMullvadState == nil || len(b.customMullvadState.peers) == 0 {
|
||||
b.currentNode().SetCustomPeers(nil)
|
||||
return
|
||||
}
|
||||
|
||||
peerViews := make([]tailcfg.NodeView, len(b.customMullvadState.peers))
|
||||
for i, p := range b.customMullvadState.peers {
|
||||
peerViews[i] = p.View()
|
||||
}
|
||||
b.currentNode().SetCustomPeers(peerViews)
|
||||
}
|
||||
|
||||
// createMullvadBootstrapDNSResolver creates a DNS resolver for Mullvad API calls
|
||||
// that bypasses MagicDNS by using public DoH servers as a fallback.
|
||||
// This solves the bootstrap problem where MagicDNS routes to Mullvad DNS
|
||||
// which is unreachable before the Mullvad tunnel is established.
|
||||
//
|
||||
// Pattern: Same as net/dns/resolver/forwarder.go getKnownDoHClientForProvider()
|
||||
func (b *LocalBackend) createMullvadBootstrapDNSResolver() *dnscache.Resolver {
|
||||
return &dnscache.Resolver{
|
||||
Forward: dnscache.Get().Forward,
|
||||
UseLastGood: true,
|
||||
LookupIPFallback: b.resolveMullvadAPIViaDoH,
|
||||
Logf: b.logf,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveMullvadAPIViaDoH resolves hostnames using public DoH providers.
|
||||
// Uses the existing publicdns package to get known DoH provider IPs,
|
||||
// avoiding hardcoded DNS servers in this code.
|
||||
// Uses SystemDial to ensure the DoH request bypasses the Tailscale tunnel.
|
||||
func (b *LocalBackend) resolveMullvadAPIViaDoH(ctx context.Context, host string) ([]netip.Addr, error) {
|
||||
// Try multiple DoH providers in order (same providers used by net/dns/resolver/forwarder.go)
|
||||
dohProviders := []string{
|
||||
"https://cloudflare-dns.com/dns-query",
|
||||
"https://dns.google/dns-query",
|
||||
"https://dns.quad9.net/dns-query",
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, dohBase := range dohProviders {
|
||||
addrs, err := b.resolveViaDoHProvider(ctx, host, dohBase)
|
||||
if err == nil && len(addrs) > 0 {
|
||||
return addrs, nil
|
||||
}
|
||||
lastErr = err
|
||||
b.logf("mullvad: DoH provider %s failed: %v", dohBase, err)
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("all DoH providers failed, last error: %w", lastErr)
|
||||
}
|
||||
return nil, errors.New("no DoH providers available")
|
||||
}
|
||||
|
||||
// resolveViaDoHProvider resolves a hostname using a specific DoH provider.
|
||||
// Uses publicdns.DoHIPsOfBase() to get known IPs for the provider.
|
||||
func (b *LocalBackend) resolveViaDoHProvider(ctx context.Context, host, dohBase string) ([]netip.Addr, error) {
|
||||
// Get known IPs for this DoH provider from publicdns package.
|
||||
// This avoids hardcoding IPs - we use Tailscale's existing known DoH provider list.
|
||||
allIPs := publicdns.DoHIPsOfBase(dohBase)
|
||||
if len(allIPs) == 0 {
|
||||
return nil, fmt.Errorf("no known IPs for DoH provider %s", dohBase)
|
||||
}
|
||||
|
||||
// Parse the DoH URL to get the hostname
|
||||
dohURL, err := url.Parse(dohBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create HTTP client that dials the DoH provider directly by IP.
|
||||
// Pattern from net/dns/resolver/forwarder.go:438-442
|
||||
dohResolver := &dnscache.Resolver{
|
||||
SingleHost: dohURL.Hostname(),
|
||||
SingleHostStaticResult: allIPs,
|
||||
}
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DialContext = dnscache.Dialer(b.dialer.SystemDial, dohResolver)
|
||||
hc := &http.Client{Transport: tr, Timeout: 10 * time.Second}
|
||||
|
||||
// Build DNS query for host (A and AAAA records)
|
||||
addrs, err := b.doDoHQuery(ctx, hc, dohBase, host, dnsmessage.TypeA)
|
||||
if err != nil {
|
||||
b.logf("mullvad: DoH A query to %s for %s failed: %v", dohBase, host, err)
|
||||
}
|
||||
|
||||
// Also try AAAA
|
||||
addrs6, err6 := b.doDoHQuery(ctx, hc, dohBase, host, dnsmessage.TypeAAAA)
|
||||
if err6 == nil {
|
||||
addrs = append(addrs, addrs6...)
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("no addresses returned from %s for %s", dohBase, host)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// doDoHQuery performs a single DoH query for the given record type.
|
||||
func (b *LocalBackend) doDoHQuery(ctx context.Context, hc *http.Client, dohBase, host string, qtype dnsmessage.Type) ([]netip.Addr, error) {
|
||||
// Build DNS query packet
|
||||
var msg dnsmessage.Message
|
||||
msg.Header.ID = uint16(time.Now().UnixNano())
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{
|
||||
{Name: dnsmessage.MustNewName(host + "."), Type: qtype, Class: dnsmessage.ClassINET},
|
||||
}
|
||||
packet, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send DoH request (RFC 8484)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", dohBase, bytes.NewReader(packet))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("DoH returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse DNS response
|
||||
var respMsg dnsmessage.Message
|
||||
if err := respMsg.Unpack(body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var addrs []netip.Addr
|
||||
for _, ans := range respMsg.Answers {
|
||||
switch r := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
addrs = append(addrs, netip.AddrFrom4(r.A))
|
||||
case *dnsmessage.AAAAResource:
|
||||
addrs = append(addrs, netip.AddrFrom16(r.AAAA))
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
@ -0,0 +1,167 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package localapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"tailscale.com/ipn/ipnlocal/mullvad"
|
||||
"tailscale.com/util/httpm"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register Mullvad endpoints unconditionally.
|
||||
// The feature flag is checked at runtime in each handler.
|
||||
Register("mullvad/status", (*Handler).serveMullvadStatus)
|
||||
Register("mullvad/configure", (*Handler).serveMullvadConfigure)
|
||||
Register("mullvad/refresh", (*Handler).serveMullvadRefresh)
|
||||
}
|
||||
|
||||
// MullvadStatusResponse is the response for the mullvad/status endpoint.
|
||||
type MullvadStatusResponse struct {
|
||||
Configured bool `json:"configured"`
|
||||
AccountExpiry time.Time `json:"accountExpiry,omitempty"`
|
||||
DaysRemaining int `json:"daysRemaining,omitempty"`
|
||||
ServerCount int `json:"serverCount,omitempty"`
|
||||
DeviceIPv4 string `json:"deviceIPv4,omitempty"`
|
||||
DeviceIPv6 string `json:"deviceIPv6,omitempty"`
|
||||
LastRefresh time.Time `json:"lastRefresh,omitempty"`
|
||||
}
|
||||
|
||||
// serveMullvadStatus returns the current custom Mullvad configuration status.
|
||||
func (h *Handler) serveMullvadStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if !mullvad.CustomMullvadEnabled() {
|
||||
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
if r.Method != httpm.GET {
|
||||
http.Error(w, "only GET allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !h.PermitRead {
|
||||
http.Error(w, "read access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
status := h.b.GetCustomMullvadStatus()
|
||||
resp := MullvadStatusResponse{
|
||||
Configured: status.Configured,
|
||||
AccountExpiry: status.AccountExpiry,
|
||||
DaysRemaining: status.DaysRemaining,
|
||||
ServerCount: status.ServerCount,
|
||||
LastRefresh: status.LastRefresh,
|
||||
}
|
||||
if status.DeviceIPv4.IsValid() {
|
||||
resp.DeviceIPv4 = status.DeviceIPv4.String()
|
||||
}
|
||||
if status.DeviceIPv6.IsValid() {
|
||||
resp.DeviceIPv6 = status.DeviceIPv6.String()
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// MullvadConfigureRequest is the request for the mullvad/configure endpoint.
|
||||
type MullvadConfigureRequest struct {
|
||||
AccountNumber string `json:"accountNumber"`
|
||||
}
|
||||
|
||||
// MullvadConfigureResponse is the response for the mullvad/configure endpoint.
|
||||
type MullvadConfigureResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
AccountExpiry time.Time `json:"accountExpiry,omitempty"`
|
||||
ServerCount int `json:"serverCount,omitempty"`
|
||||
}
|
||||
|
||||
// serveMullvadConfigure configures the custom Mullvad account.
|
||||
func (h *Handler) serveMullvadConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
if !mullvad.CustomMullvadEnabled() {
|
||||
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
if r.Method != httpm.POST {
|
||||
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !h.PermitWrite {
|
||||
http.Error(w, "write access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var req MullvadConfigureRequest
|
||||
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<16)).Decode(&req); err != nil { // 64KB limit
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Use LocalBackend to configure Mullvad
|
||||
// This is done via preferences to ensure persistence
|
||||
if err := h.b.ConfigureCustomMullvad(r.Context(), req.AccountNumber); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, mullvad.ErrInvalidAccount):
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
case errors.Is(err, mullvad.ErrAccountExpired):
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
case errors.Is(err, mullvad.ErrNotEnabled):
|
||||
http.Error(w, err.Error(), http.StatusNotImplemented)
|
||||
default:
|
||||
WriteErrorJSON(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
status := h.b.GetCustomMullvadStatus()
|
||||
resp := MullvadConfigureResponse{
|
||||
Success: true,
|
||||
AccountExpiry: status.AccountExpiry,
|
||||
ServerCount: status.ServerCount,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// serveMullvadRefresh forces a refresh of the Mullvad server list.
|
||||
func (h *Handler) serveMullvadRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
if !mullvad.CustomMullvadEnabled() {
|
||||
http.Error(w, "custom Mullvad support not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
if r.Method != httpm.POST {
|
||||
http.Error(w, "only POST allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !h.PermitWrite {
|
||||
http.Error(w, "write access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.b.RefreshCustomMullvad(r.Context()); err != nil {
|
||||
WriteErrorJSON(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
status := h.b.GetCustomMullvadStatus()
|
||||
resp := MullvadStatusResponse{
|
||||
Configured: status.Configured,
|
||||
AccountExpiry: status.AccountExpiry,
|
||||
DaysRemaining: status.DaysRemaining,
|
||||
ServerCount: status.ServerCount,
|
||||
LastRefresh: status.LastRefresh,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// Ensure ipnlocal.LocalBackend has the necessary methods.
|
||||
// These are defined in mullvad_integration.go.
|
||||
// Note: We use context.Context from the standard library, not an interface type.
|
||||
// This is just a compile-time check that the methods exist.
|
||||
Loading…
Reference in New Issue