diff --git a/Makefile b/Makefile index 472d70e..64a039e 100644 --- a/Makefile +++ b/Makefile @@ -147,7 +147,7 @@ $(LIBTAILSCALE): Makefile android/libs $(LIBTAILSCALE_SOURCES) $(GOBIN)/gomobile -ldflags "$(FULL_LDFLAGS)" \ -o $@ ./libtailscale -libtailscale: $(LIBTAILSCALE) +libtailscale: $(LIBTAILSCALE) ## Build libtailscale ANDROID_SOURCES=$(shell find android -type f -not -path "android/build/*" -not -path '*/.*') DEBUG_INTERMEDIARY = android/build/outputs/apk/debug/android-debug.apk diff --git a/android/src/main/java/com/tailscale/ipn/VPNServiceBuilder.kt b/android/src/main/java/com/tailscale/ipn/VPNServiceBuilder.kt index 25abb52..eebbe10 100644 --- a/android/src/main/java/com/tailscale/ipn/VPNServiceBuilder.kt +++ b/android/src/main/java/com/tailscale/ipn/VPNServiceBuilder.kt @@ -5,6 +5,8 @@ package com.tailscale.ipn import android.net.VpnService import libtailscale.ParcelFileDescriptor +import java.net.InetAddress +import android.net.IpPrefix as AndroidIpPrefix class VPNServiceBuilder(private val builder: VpnService.Builder) : libtailscale.VPNServiceBuilder { override fun addAddress(p0: String, p1: Int) { @@ -19,6 +21,12 @@ class VPNServiceBuilder(private val builder: VpnService.Builder) : libtailscale. builder.addRoute(p0, p1) } + override fun excludeRoute(p0: String, p1: Int) { + val inetAddress = InetAddress.getByName(p0) + val prefix = AndroidIpPrefix(inetAddress, p1) + builder.excludeRoute(prefix) + } + override fun addSearchDomain(p0: String) { builder.addSearchDomain(p0) } diff --git a/libtailscale/backend.go b/libtailscale/backend.go index 0cd9cb7..73602c8 100644 --- a/libtailscale/backend.go +++ b/libtailscale/backend.go @@ -5,6 +5,7 @@ package libtailscale import ( "context" + "errors" "fmt" "log" "net/http" @@ -166,6 +167,27 @@ func (a *App) runBackend(ctx context.Context) error { select { case s := <-stateCh: state = s + if cfg.rcfg != nil && state >= ipn.Starting && service != nil { + // On state change, check if there are router or config changes requiring an update to VPNBuilder + if err := b.updateTUN(service, cfg.rcfg, cfg.dcfg); err != nil { + if errors.Is(err, errMultipleUsers) { + // TODO: surface error to user + } + log.Printf("VPN update failed: %v", err) + + mp := new(ipn.MaskedPrefs) + mp.WantRunning = false + mp.WantRunningSet = true + + _, err := a.EditPrefs(*mp) + if err != nil { + log.Printf("localapi edit prefs error %v", err) + } + + b.lastCfg = nil + b.CloseTUNs() + } + } case n := <-netmapCh: networkMap = n case c := <-configs: @@ -214,6 +236,8 @@ func (a *App) runBackend(ctx context.Context) error { if err := b.updateTUN(service, cfg.rcfg, cfg.dcfg); err != nil { log.Printf("VPN update failed: %v", err) service.Close() + b.lastCfg = nil + b.CloseTUNs() } } case s := <-onDisconnect: diff --git a/libtailscale/callbacks.go b/libtailscale/callbacks.go index 47b7931..e79a173 100644 --- a/libtailscale/callbacks.go +++ b/libtailscale/callbacks.go @@ -8,11 +8,6 @@ import ( ) var ( - // onVPNPrepared is notified when VpnService.prepare succeeds. - onVPNPrepared = make(chan struct{}, 1) - // onVPNRevoked is notified whenever the VPN service is revoked. - onVPNRevoked = make(chan struct{}, 1) - // onVPNRequested receives global IPNService references when // a VPN connection is requested. onVPNRequested = make(chan IPNService) @@ -35,20 +30,6 @@ func OnDNSConfigChanged(ifname string) { } } -func notifyVPNPrepared() { - select { - case onVPNPrepared <- struct{}{}: - default: - } -} - -func notifyVPNRevoked() { - select { - case onVPNRevoked <- struct{}{}: - default: - } -} - var android struct { // mu protects all fields of this structure. However, once a // non-nil jvm is returned from javaVM, all the other fields may diff --git a/libtailscale/interfaces.go b/libtailscale/interfaces.go index d7f3227..317d05e 100644 --- a/libtailscale/interfaces.go +++ b/libtailscale/interfaces.go @@ -81,6 +81,7 @@ type VPNServiceBuilder interface { AddDNSServer(string) error AddSearchDomain(string) error AddRoute(string, int32) error + ExcludeRoute(string, int32) error AddAddress(string, int32) error Establish() (ParcelFileDescriptor, error) } diff --git a/libtailscale/localapi.go b/libtailscale/localapi.go index 974c8e5..bed54b7 100644 --- a/libtailscale/localapi.go +++ b/libtailscale/localapi.go @@ -5,6 +5,7 @@ package libtailscale import ( "context" + "encoding/json" "fmt" "io" "log" @@ -18,9 +19,12 @@ import ( "strings" "sync" "time" + + "tailscale.com/ipn" ) -// CallLocalAPI calls the given endpoint on the local API using the given HTTP method +// CallLocalAPI is the method for making localapi calls from Kotlin. It calls +// the given endpoint on the local API using the given HTTP method and // optionally sending the given body. It returns a Response representing the // result of the call and an error if the call could not be completed or the // local API returned a status code in the 400 series or greater. @@ -103,6 +107,18 @@ func (app *App) CallLocalAPIMultipart(timeoutMillis int, method, endpoint string } } +func (app *App) EditPrefs(prefs ipn.MaskedPrefs) (LocalAPIResponse, error) { + r, w := io.Pipe() + go func() { + defer w.Close() + enc := json.NewEncoder(w) + if err := enc.Encode(prefs); err != nil { + log.Printf("Error encoding preferences: %v", err) + } + }() + return app.callLocalAPI(30000, "PATCH", "prefs", nil, r) +} + func (app *App) callLocalAPI(timeoutMillis int, method, endpoint string, header http.Header, body io.ReadCloser) (LocalAPIResponse, error) { defer func() { if p := recover(); p != nil { diff --git a/libtailscale/net.go b/libtailscale/net.go index f90e65b..4165b0b 100644 --- a/libtailscale/net.go +++ b/libtailscale/net.go @@ -170,6 +170,17 @@ func (b *backend) updateTUN(service IPNService, rcfg *router.Config, dcfg *dns.O return err } } + + for _, route := range rcfg.LocalRoutes { + addr := route.Addr() + if addr.IsLoopback() { + continue // Skip the loopback addresses since VpnService throws an exception for those (both IPv4 and IPv6) - see https://android.googlesource.com/platform/frameworks/base/+/c741553/core/java/android/net/VpnService.java#303 + } + route = route.Masked() + if err := builder.ExcludeRoute(route.Addr().String(), int32(route.Bits())); err != nil { + return err + } + } b.logger.Logf("updateTUN: added %d routes", len(rcfg.Routes)) for _, addr := range rcfg.LocalAddrs { @@ -210,12 +221,6 @@ func (b *backend) updateTUN(service IPNService, rcfg *router.Config, dcfg *dns.O b.devices.add(tunDev) b.logger.Logf("updateTUN: added TUN device") - // TODO(oxtoacart): figure out what to do with this - // if err != nil { - // b.lastCfg = nil - // b.CloseTUNs() - // return err - // } b.lastCfg = rcfg b.lastDNSCfg = dcfg return nil