diff --git a/.gitignore b/.gitignore index b5d5ae6..7f87946 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ build android/libs android_legacy/libs +# Ignore ABI +android/src/main/jniLibs/* + # Android Studio files android_legacy/.idea android_legacy/local.properties diff --git a/Makefile b/Makefile index feb85fb..b33a243 100644 --- a/Makefile +++ b/Makefile @@ -163,8 +163,8 @@ tailscale-new-fdroid.apk: $(AAR_NEXTGEN) (cd android && ./gradlew test assembleFdroidDebug) mv android/build/outputs/apk/fdroid/debug/android-fdroid-debug.apk $@ -tailscale-new-debug.apk: $(AAR_NEXTGEN) - (cd android && ./gradlew test assemblePlayDebug) +tailscale-new-debug.apk: + (cd android && ./gradlew test buildAllGoLibs assemblePlayDebug) mv android/build/outputs/apk/play/debug/android-play-debug.apk $@ tailscale-new-debug: tailscale-new-debug.apk diff --git a/android/build.gradle b/android/build.gradle index e88214d..255d6f1 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -105,3 +105,62 @@ dependencies { // Non-free dependencies. playImplementation 'com.google.android.gms:play-services-auth:20.7.0' } + +def ndkPath = project.hasProperty('ndkPath') ? project.property('ndkPath') : System.getenv('ANDROID_SDK_ROOT') + +task checkNDK { + doFirst { + if (ndkPath == null) { + throw new GradleException('NDK path not found. Please define ndkPath in local.properties or ANDROID_SDK_HOME environment variable.') + } + } +} + +task buildGoLibArm64(type: Exec) { + inputs.dir '../pkg/tailscale' + outputs.file 'src/main/jniLibs/arm64-v8a/libtailscale.so' + environment "CC", "$ndkPath/ndk-bundle/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android30-clang" + commandLine 'sh', '-c', "GOOS=android GOARCH=arm64 CGO_ENABLED=1 go build -buildmode=c-shared -ldflags=-w -o src/main/jniLibs/arm64-v8a/libtailscale.so ../pkg/tailscale" +} + +task buildGoLibArmeabi(type: Exec) { + inputs.dir '../pkg/tailscale' + outputs.file 'src/main/jniLibs/armeabi-v7a/libtailscale.so' + environment "CC", "$ndkPath/ndk-bundle/toolchains/llvm/prebuilt/darwin-x86_64/bin/armv7a-linux-androideabi30-clang" + commandLine 'sh', '-c', "GOOS=android GOARCH=arm CGO_ENABLED=1 go build -buildmode=c-shared -ldflags=-w -o src/main/jniLibs/armeabi-v7a/libtailscale.so ../pkg/tailscale" +} + +task buildGoLibX86(type: Exec) { + inputs.dir '../pkg/tailscale' + outputs.file 'src/main/jniLibs/x86/libtailscale.so' + environment "CC", "$ndkPath/ndk-bundle/toolchains/llvm/prebuilt/darwin-x86_64/bin/i686-linux-android30-clang" + commandLine 'sh', '-c', "GOOS=android GOARCH=386 CGO_ENABLED=1 go build -buildmode=c-shared -ldflags=-w -o src/main/jniLibs/x86/libtailscale.so ../pkg/tailscale" +} + +task buildGoLibX86_64(type: Exec) { + inputs.dir '../pkg/tailscale' + outputs.file 'src/main/jniLibs/x86_64/libtailscale.so' + environment "CC", "$ndkPath/ndk-bundle/toolchains/llvm/prebuilt/darwin-x86_64/bin/x86_64-linux-android30-clang" + commandLine 'sh', '-c', "GOOS=android GOARCH=amd64 CGO_ENABLED=1 go build -buildmode=c-shared -ldflags=-w -o src/main/jniLibs/x86_64/libtailscale.so ../pkg/tailscale" +} + +task buildAllGoLibs { + dependsOn checkNDK, buildGoLibArm64, buildGoLibArmeabi, buildGoLibX86, buildGoLibX86_64 +} + +assemble.dependsOn buildAllGoLibs + +task cleanGoLibs(type: Delete) { + delete 'src/main/jniLibs/arm64-v8a/libtailscale.so', + 'src/main/jniLibs/armeabi-v7a/libtailscale.so', + 'src/main/jniLibs/x86/libtailscale.so', + 'src/main/jniLibs/x86_64/libtailscale.so' +} + +clean.dependsOn cleanGoLibs + +tasks.whenTaskAdded { task -> + if (task.name.startsWith('merge') && task.name.endsWith('JniLibFolders')) { + task.mustRunAfter buildAllGoLibs + } +} diff --git a/android/src/main/java/com/tailscale/ipn/App.java b/android/src/main/java/com/tailscale/ipn/App.java index 7789bb7..ed6f4b9 100644 --- a/android/src/main/java/com/tailscale/ipn/App.java +++ b/android/src/main/java/com/tailscale/ipn/App.java @@ -47,8 +47,6 @@ import com.tailscale.ipn.mdm.MDMSettings; import com.tailscale.ipn.mdm.ShowHideSetting; import com.tailscale.ipn.mdm.StringSetting; -import org.gioui.Gio; - import java.io.File; import java.io.IOException; import java.net.InetAddress; @@ -88,6 +86,8 @@ public class App extends Application { f.startActivityForResult(intent, request); } + static native void initBackend(byte[] dataDir, Context context); + static native void onVPNPrepared(); private static native void onDnsConfigChanged(); @@ -103,8 +103,17 @@ public class App extends Application { @Override public void onCreate() { super.onCreate(); - // Load and initialize the Go library. - Gio.init(this); + + System.loadLibrary("tailscale"); + + String dataDir = this.getFilesDir().getAbsolutePath(); + byte[] dataDirUTF8; + try { + dataDirUTF8 = dataDir.getBytes("UTF-8"); + initBackend(dataDirUTF8, this); + } catch (Exception e) { + android.util.Log.d("tailscale", "Error getting directory"); + } this.connectivityManager = (ConnectivityManager) this.getSystemService(Context.CONNECTIVITY_SERVICE); setAndRegisterNetworkCallbacks(); diff --git a/android/src/main/java/com/tailscale/ipn/IPNActivity.java b/android/src/main/java/com/tailscale/ipn/IPNActivity.java index bb53822..928abe6 100644 --- a/android/src/main/java/com/tailscale/ipn/IPNActivity.java +++ b/android/src/main/java/com/tailscale/ipn/IPNActivity.java @@ -12,20 +12,14 @@ import android.net.Uri; import android.os.Bundle; import android.provider.OpenableColumns; -import org.gioui.GioView; - import java.util.List; public final class IPNActivity extends Activity { final static int WRITE_STORAGE_RESULT = 1000; - private GioView view; - @Override public void onCreate(Bundle state) { super.onCreate(state); - view = new GioView(this); - setContentView(view); handleIntent(); } @@ -102,37 +96,26 @@ public final class IPNActivity extends Activity { @Override public void onDestroy() { - view.destroy(); super.onDestroy(); } @Override public void onStart() { super.onStart(); - view.start(); } @Override public void onStop() { - view.stop(); super.onStop(); } @Override public void onConfigurationChanged(Configuration c) { super.onConfigurationChanged(c); - view.configurationChanged(); } @Override public void onLowMemory() { super.onLowMemory(); - GioView.onLowMemory(); - } - - @Override - public void onBackPressed() { - if (!view.backPressed()) - super.onBackPressed(); } } diff --git a/pkg/jni/jnipkg.go b/pkg/jni/jnipkg.go new file mode 100644 index 0000000..d32db06 --- /dev/null +++ b/pkg/jni/jnipkg.go @@ -0,0 +1,506 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jni implements various helper functions for communicating with the Android JVM +// though JNI. +package jnipkg + +import ( + "errors" + "fmt" + "reflect" + "runtime" + "sync" + "unicode/utf16" + "unsafe" +) + +/* +#cgo CFLAGS: -Wall + +#include +#include + +static jint jni_AttachCurrentThread(JavaVM *vm, JNIEnv **p_env, void *thr_args) { + return (*vm)->AttachCurrentThread(vm, p_env, thr_args); +} + +static jint jni_DetachCurrentThread(JavaVM *vm) { + return (*vm)->DetachCurrentThread(vm); +} + +static jint jni_GetEnv(JavaVM *vm, JNIEnv **env, jint version) { + return (*vm)->GetEnv(vm, (void **)env, version); +} + +static jclass jni_FindClass(JNIEnv *env, const char *name) { + return (*env)->FindClass(env, name); +} + +static jthrowable jni_ExceptionOccurred(JNIEnv *env) { + return (*env)->ExceptionOccurred(env); +} + +static void jni_ExceptionClear(JNIEnv *env) { + (*env)->ExceptionClear(env); +} + +static jclass jni_GetObjectClass(JNIEnv *env, jobject obj) { + return (*env)->GetObjectClass(env, obj); +} + +static jmethodID jni_GetMethodID(JNIEnv *env, jclass clazz, const char *name, const char *sig) { + return (*env)->GetMethodID(env, clazz, name, sig); +} + +static jmethodID jni_GetStaticMethodID(JNIEnv *env, jclass clazz, const char *name, const char *sig) { + return (*env)->GetStaticMethodID(env, clazz, name, sig); +} + +static jsize jni_GetStringLength(JNIEnv *env, jstring str) { + return (*env)->GetStringLength(env, str); +} + +static const jchar *jni_GetStringChars(JNIEnv *env, jstring str) { + return (*env)->GetStringChars(env, str, NULL); +} + +static jstring jni_NewString(JNIEnv *env, const jchar *unicodeChars, jsize len) { + return (*env)->NewString(env, unicodeChars, len); +} + +static jboolean jni_IsSameObject(JNIEnv *env, jobject ref1, jobject ref2) { + return (*env)->IsSameObject(env, ref1, ref2); +} + +static jobject jni_NewGlobalRef(JNIEnv *env, jobject obj) { + return (*env)->NewGlobalRef(env, obj); +} + +static void jni_DeleteGlobalRef(JNIEnv *env, jobject obj) { + (*env)->DeleteGlobalRef(env, obj); +} + +static void jni_CallStaticVoidMethodA(JNIEnv *env, jclass cls, jmethodID method, jvalue *args) { + (*env)->CallStaticVoidMethodA(env, cls, method, args); +} + +static jint jni_CallStaticIntMethodA(JNIEnv *env, jclass cls, jmethodID method, jvalue *args) { + return (*env)->CallStaticIntMethodA(env, cls, method, args); +} + +static jobject jni_CallStaticObjectMethodA(JNIEnv *env, jclass cls, jmethodID method, jvalue *args) { + return (*env)->CallStaticObjectMethodA(env, cls, method, args); +} + +static jobject jni_CallObjectMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args) { + return (*env)->CallObjectMethodA(env, obj, method, args); +} + +static jboolean jni_CallBooleanMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args) { + return (*env)->CallBooleanMethodA(env, obj, method, args); +} + +static jint jni_CallIntMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args) { + return (*env)->CallIntMethodA(env, obj, method, args); +} + +static void jni_CallVoidMethodA(JNIEnv *env, jobject obj, jmethodID method, jvalue *args) { + (*env)->CallVoidMethodA(env, obj, method, args); +} + +static jbyteArray jni_NewByteArray(JNIEnv *env, jsize length) { + return (*env)->NewByteArray(env, length); +} + +static jboolean *jni_GetBooleanArrayElements(JNIEnv *env, jbooleanArray arr) { + return (*env)->GetBooleanArrayElements(env, arr, NULL); +} + +static void jni_ReleaseBooleanArrayElements(JNIEnv *env, jbooleanArray arr, jboolean *elems, jint mode) { + (*env)->ReleaseBooleanArrayElements(env, arr, elems, mode); +} + +static jbyte *jni_GetByteArrayElements(JNIEnv *env, jbyteArray arr) { + return (*env)->GetByteArrayElements(env, arr, NULL); +} + +static jint *jni_GetIntArrayElements(JNIEnv *env, jintArray arr) { + return (*env)->GetIntArrayElements(env, arr, NULL); +} + +static void jni_ReleaseIntArrayElements(JNIEnv *env, jintArray arr, jint *elems, jint mode) { + (*env)->ReleaseIntArrayElements(env, arr, elems, mode); +} + +static jlong *jni_GetLongArrayElements(JNIEnv *env, jlongArray arr) { + return (*env)->GetLongArrayElements(env, arr, NULL); +} + +static void jni_ReleaseLongArrayElements(JNIEnv *env, jlongArray arr, jlong *elems, jint mode) { + (*env)->ReleaseLongArrayElements(env, arr, elems, mode); +} + +static void jni_ReleaseByteArrayElements(JNIEnv *env, jbyteArray arr, jbyte *elems, jint mode) { + (*env)->ReleaseByteArrayElements(env, arr, elems, mode); +} + +static jsize jni_GetArrayLength(JNIEnv *env, jarray arr) { + return (*env)->GetArrayLength(env, arr); +} + +static void jni_DeleteLocalRef(JNIEnv *env, jobject localRef) { + return (*env)->DeleteLocalRef(env, localRef); +} + +static jobject jni_GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) { + return (*env)->GetObjectArrayElement(env, array, index); +} + +static jboolean jni_IsInstanceOf(JNIEnv *env, jobject obj, jclass clazz) { + return (*env)->IsInstanceOf(env, obj, clazz); +} + +static jint jni_GetJavaVM(JNIEnv *env, JavaVM **jvm) { + return (*env)->GetJavaVM(env, jvm); +} +*/ +import "C" + +type JVM C.JavaVM + +type Env C.JNIEnv + +type ( + Class C.jclass + Object C.jobject + MethodID C.jmethodID + String C.jstring + ByteArray C.jbyteArray + ObjectArray C.jobjectArray + BooleanArray C.jbooleanArray + LongArray C.jlongArray + IntArray C.jintArray + Boolean C.jboolean + Value uint64 // All JNI types fit into 64-bits. +) + +// Cached class handles. +var classes struct { + once sync.Once + stringClass, integerClass Class + + integerIntValue MethodID +} + +func env(e *Env) *C.JNIEnv { + return (*C.JNIEnv)(unsafe.Pointer(e)) +} + +func javavm(vm *JVM) *C.JavaVM { + return (*C.JavaVM)(unsafe.Pointer(vm)) +} + +// Do invokes a function with a temporary JVM environment. The +// environment is not valid after the function returns. +func Do(vm *JVM, f func(env *Env) error) error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + var env *C.JNIEnv + if res := C.jni_GetEnv(javavm(vm), &env, C.JNI_VERSION_1_6); res != C.JNI_OK { + if res != C.JNI_EDETACHED { + panic(fmt.Errorf("JNI GetEnv failed with error %d", res)) + } + if C.jni_AttachCurrentThread(javavm(vm), &env, nil) != C.JNI_OK { + panic(errors.New("runInJVM: AttachCurrentThread failed")) + } + defer C.jni_DetachCurrentThread(javavm(vm)) + } + + return f((*Env)(unsafe.Pointer(env))) +} + +func Bool(b bool) Boolean { + if b { + return C.JNI_TRUE + } + return C.JNI_FALSE +} + +func varArgs(args []Value) *C.jvalue { + if len(args) == 0 { + return nil + } + return (*C.jvalue)(unsafe.Pointer(&args[0])) +} + +func IsSameObject(e *Env, ref1, ref2 Object) bool { + same := C.jni_IsSameObject(env(e), C.jobject(ref1), C.jobject(ref2)) + return same == C.JNI_TRUE +} + +func CallStaticIntMethod(e *Env, cls Class, method MethodID, args ...Value) (int, error) { + res := C.jni_CallStaticIntMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args)) + return int(res), exception(e) +} + +func CallStaticVoidMethod(e *Env, cls Class, method MethodID, args ...Value) error { + C.jni_CallStaticVoidMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args)) + return exception(e) +} + +func CallVoidMethod(e *Env, obj Object, method MethodID, args ...Value) error { + C.jni_CallVoidMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args)) + return exception(e) +} + +func CallStaticObjectMethod(e *Env, cls Class, method MethodID, args ...Value) (Object, error) { + res := C.jni_CallStaticObjectMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args)) + return Object(res), exception(e) +} + +func CallObjectMethod(e *Env, obj Object, method MethodID, args ...Value) (Object, error) { + res := C.jni_CallObjectMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args)) + return Object(res), exception(e) +} + +func CallBooleanMethod(e *Env, obj Object, method MethodID, args ...Value) (bool, error) { + res := C.jni_CallBooleanMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args)) + return res == C.JNI_TRUE, exception(e) +} + +func CallIntMethod(e *Env, obj Object, method MethodID, args ...Value) (int32, error) { + res := C.jni_CallIntMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args)) + return int32(res), exception(e) +} + +func GetArrayLength(e *Env, jarr ByteArray) int { + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + return int(size) +} + +// GetByteArrayElements returns the contents of the byte array. +func GetByteArrayElements(e *Env, jarr ByteArray) []byte { + if jarr == 0 { + return nil + } + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + elems := C.jni_GetByteArrayElements(env(e), C.jbyteArray(jarr)) + defer C.jni_ReleaseByteArrayElements(env(e), C.jbyteArray(jarr), elems, 0) + backing := (*(*[1 << 30]byte)(unsafe.Pointer(elems)))[:size:size] + s := make([]byte, len(backing)) + copy(s, backing) + return s +} + +// GetBooleanArrayElements returns the contents of the boolean array. +func GetBooleanArrayElements(e *Env, jarr BooleanArray) []bool { + if jarr == 0 { + return nil + } + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + elems := C.jni_GetBooleanArrayElements(env(e), C.jbooleanArray(jarr)) + defer C.jni_ReleaseBooleanArrayElements(env(e), C.jbooleanArray(jarr), elems, 0) + backing := (*(*[1 << 30]C.jboolean)(unsafe.Pointer(elems)))[:size:size] + r := make([]bool, len(backing)) + for i, b := range backing { + r[i] = b == C.JNI_TRUE + } + return r +} + +// GetStringArrayElements returns the contents of the String array. +func GetStringArrayElements(e *Env, jarr ObjectArray) []string { + var strings []string + iterateObjectArray(e, jarr, func(e *Env, idx int, item Object) { + s := GoString(e, String(item)) + strings = append(strings, s) + }) + return strings +} + +// GetIntArrayElements returns the contents of the int array. +func GetIntArrayElements(e *Env, jarr IntArray) []int { + if jarr == 0 { + return nil + } + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + elems := C.jni_GetIntArrayElements(env(e), C.jintArray(jarr)) + defer C.jni_ReleaseIntArrayElements(env(e), C.jintArray(jarr), elems, 0) + backing := (*(*[1 << 27]C.jint)(unsafe.Pointer(elems)))[:size:size] + r := make([]int, len(backing)) + for i, l := range backing { + r[i] = int(l) + } + return r +} + +// GetLongArrayElements returns the contents of the long array. +func GetLongArrayElements(e *Env, jarr LongArray) []int64 { + if jarr == 0 { + return nil + } + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + elems := C.jni_GetLongArrayElements(env(e), C.jlongArray(jarr)) + defer C.jni_ReleaseLongArrayElements(env(e), C.jlongArray(jarr), elems, 0) + backing := (*(*[1 << 27]C.jlong)(unsafe.Pointer(elems)))[:size:size] + r := make([]int64, len(backing)) + for i, l := range backing { + r[i] = int64(l) + } + return r +} + +func iterateObjectArray(e *Env, jarr ObjectArray, f func(e *Env, idx int, item Object)) { + if jarr == 0 { + return + } + size := C.jni_GetArrayLength(env(e), C.jarray(jarr)) + for i := 0; i < int(size); i++ { + item := C.jni_GetObjectArrayElement(env(e), C.jobjectArray(jarr), C.jint(i)) + f(e, i, Object(item)) + C.jni_DeleteLocalRef(env(e), item) + } +} + +// NewByteArray allocates a Java byte array with the content. It +// panics if the allocation fails. +func NewByteArray(e *Env, content []byte) ByteArray { + jarr := C.jni_NewByteArray(env(e), C.jsize(len(content))) + if jarr == 0 { + panic(fmt.Errorf("jni: NewByteArray(%d) failed", len(content))) + } + elems := C.jni_GetByteArrayElements(env(e), jarr) + defer C.jni_ReleaseByteArrayElements(env(e), jarr, elems, 0) + backing := (*(*[1 << 30]byte)(unsafe.Pointer(elems)))[:len(content):len(content)] + copy(backing, content) + return ByteArray(jarr) +} + +// ClassLoader returns a reference to the Java ClassLoader associated +// with obj. +func ClassLoaderFor(e *Env, obj Object) Object { + cls := GetObjectClass(e, obj) + getClassLoader := GetMethodID(e, cls, "getClassLoader", "()Ljava/lang/ClassLoader;") + clsLoader, err := CallObjectMethod(e, Object(obj), getClassLoader) + if err != nil { + // Class.getClassLoader should never fail. + panic(err) + } + return Object(clsLoader) +} + +// LoadClass invokes the underlying ClassLoader's loadClass method and +// returns the class. +func LoadClass(e *Env, loader Object, class string) (Class, error) { + cls := GetObjectClass(e, loader) + loadClass := GetMethodID(e, cls, "loadClass", "(Ljava/lang/String;)Ljava/lang/Class;") + name := JavaString(e, class) + loaded, err := CallObjectMethod(e, loader, loadClass, Value(name)) + if err != nil { + return 0, err + } + return Class(loaded), exception(e) +} + +// exception returns an error corresponding to the pending +// exception, and clears it. exceptionError returns nil if no +// exception is pending. +func exception(e *Env) error { + thr := C.jni_ExceptionOccurred(env(e)) + if thr == 0 { + return nil + } + C.jni_ExceptionClear(env(e)) + cls := GetObjectClass(e, Object(thr)) + toString := GetMethodID(e, cls, "toString", "()Ljava/lang/String;") + msg, err := CallObjectMethod(e, Object(thr), toString) + if err != nil { + return err + } + return errors.New(GoString(e, String(msg))) +} + +// GetObjectClass returns the Java Class for an Object. +func GetObjectClass(e *Env, obj Object) Class { + if obj == 0 { + panic("null object") + } + cls := C.jni_GetObjectClass(env(e), C.jobject(obj)) + if err := exception(e); err != nil { + // GetObjectClass should never fail. + panic(err) + } + return Class(cls) +} + +// GetStaticMethodID returns the id for a static method. It panics if the method +// wasn't found. +func GetStaticMethodID(e *Env, cls Class, name, signature string) MethodID { + mname := C.CString(name) + defer C.free(unsafe.Pointer(mname)) + msig := C.CString(signature) + defer C.free(unsafe.Pointer(msig)) + m := C.jni_GetStaticMethodID(env(e), C.jclass(cls), mname, msig) + if err := exception(e); err != nil { + panic(err) + } + return MethodID(m) +} + +// GetMethodID returns the id for a method. It panics if the method +// wasn't found. +func GetMethodID(e *Env, cls Class, name, signature string) MethodID { + mname := C.CString(name) + defer C.free(unsafe.Pointer(mname)) + msig := C.CString(signature) + defer C.free(unsafe.Pointer(msig)) + m := C.jni_GetMethodID(env(e), C.jclass(cls), mname, msig) + if err := exception(e); err != nil { + panic(err) + } + return MethodID(m) +} + +func NewGlobalRef(e *Env, obj Object) Object { + return Object(C.jni_NewGlobalRef(env(e), C.jobject(obj))) +} + +func DeleteGlobalRef(e *Env, obj Object) { + C.jni_DeleteGlobalRef(env(e), C.jobject(obj)) +} + +// JavaString converts the string to a JVM jstring. +func JavaString(e *Env, str string) String { + if str == "" { + return 0 + } + utf16Chars := utf16.Encode([]rune(str)) + res := C.jni_NewString(env(e), (*C.jchar)(unsafe.Pointer(&utf16Chars[0])), C.int(len(utf16Chars))) + return String(res) +} + +// GoString converts the JVM jstring to a Go string. +func GoString(e *Env, str String) string { + if str == 0 { + return "" + } + strlen := C.jni_GetStringLength(env(e), C.jstring(str)) + chars := C.jni_GetStringChars(env(e), C.jstring(str)) + var utf16Chars []uint16 + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&utf16Chars)) + hdr.Data = uintptr(unsafe.Pointer(chars)) + hdr.Cap = int(strlen) + hdr.Len = int(strlen) + utf8 := utf16.Decode(utf16Chars) + return string(utf8) +} + +func GetJavaVM(e *Env) (*JVM, error) { + var jvm *C.JavaVM + result := C.jni_GetJavaVM(env(e), &jvm) + if result != C.JNI_OK { + return nil, errors.New("failed to get JavaVM") + } + return (*JVM)(jvm), nil +} diff --git a/pkg/localapiservice/localapi_test.go b/pkg/localapiservice/localapi_test.go new file mode 100644 index 0000000..97f0d06 --- /dev/null +++ b/pkg/localapiservice/localapi_test.go @@ -0,0 +1,76 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package localapiservice + +import ( + "context" + "io" + "net/http" + "testing" + "time" +) + +var ctx = context.Background() + +type BadStatusHandler struct{} + +func (b *BadStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) +} + +func TestBadStatus(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&BadStatusHandler{}) + defer cancel() + + _, err := client.Call(ctx, "POST", "test", nil) + + if err.Error() != "request failed with status code 400" { + t.Error("Expected bad status error, but got", err) + } +} + +type TimeoutHandler struct{} + +var successfulResponse = "successful response!" + +func (b *TimeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + time.Sleep(6 * time.Second) + w.Write([]byte(successfulResponse)) +} + +func TestTimeout(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&TimeoutHandler{}) + defer cancel() + + _, err := client.Call(ctx, "GET", "test", nil) + + if err.Error() != "timeout for test" { + t.Error("Expected timeout error, but got", err) + } +} + +type SuccessfulHandler struct{} + +func (b *SuccessfulHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(successfulResponse)) +} + +func TestSuccess(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&SuccessfulHandler{}) + defer cancel() + + w, err := client.Call(ctx, "GET", "test", nil) + + if err != nil { + t.Error("Expected no error, but got", err) + } + + report, err := io.ReadAll(w.Body()) + if string(report) != successfulResponse { + t.Error("Expected successful report, but got", report) + } +} diff --git a/pkg/localapiservice/localapiservice.go b/pkg/localapiservice/localapiservice.go new file mode 100644 index 0000000..6fa88f5 --- /dev/null +++ b/pkg/localapiservice/localapiservice.go @@ -0,0 +1,113 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package localapiservice + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + "time" + + "tailscale.com/ipn/ipnlocal" +) + +type LocalAPIService struct { + h http.Handler +} + +func New(h http.Handler) *LocalAPIService { + return &LocalAPIService{h: h} +} + +// Call calls the given endpoint on the local API using the given HTTP method +// optionally sending the given body. It returns a Response representing the +// result of the call and an error if the call could not be completed or the +// local API returned a status code in the 400 series or greater. +// Note - Response includes a response body available from the Body method, it +// is the caller's responsibility to close this. +func (cl *LocalAPIService) Call(ctx context.Context, method, endpoint string, body io.Reader) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, method, endpoint, body) + if err != nil { + return nil, fmt.Errorf("error creating new request for %s: %w", endpoint, err) + } + deadline, _ := ctx.Deadline() + pipeReader, pipeWriter := net.Pipe() + pipeReader.SetDeadline(deadline) + pipeWriter.SetDeadline(deadline) + + resp := &Response{ + headers: http.Header{}, + status: http.StatusOK, + bodyReader: pipeReader, + bodyWriter: pipeWriter, + startWritingBody: make(chan interface{}), + } + + go func() { + cl.h.ServeHTTP(resp, req) + resp.Flush() + pipeWriter.Close() + }() + + select { + case <-resp.startWritingBody: + if resp.StatusCode() >= 400 { + return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode()) + } + return resp, nil + case <-ctx.Done(): + return nil, fmt.Errorf("timeout for %s", endpoint) + } +} + +func (s *LocalAPIService) GetBugReportID(ctx context.Context, bugReportChan chan<- string, fallbackLog string) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + r, err := s.Call(ctx, "POST", "/localapi/v0/bugreport", nil) + defer r.Body().Close() + + if err != nil { + log.Printf("get bug report: %s", err) + bugReportChan <- fallbackLog + return + } + logBytes, err := io.ReadAll(r.Body()) + if err != nil { + log.Printf("read bug report: %s", err) + bugReportChan <- fallbackLog + return + } + bugReportChan <- string(logBytes) +} + +func (s *LocalAPIService) Login(ctx context.Context, backend *ipnlocal.LocalBackend) { + ctx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + r, err := s.Call(ctx, "POST", "/localapi/v0/login-interactive", nil) + defer r.Body().Close() + + if err != nil { + log.Printf("login: %s", err) + backend.StartLoginInteractive() + } +} + +func (s *LocalAPIService) Logout(ctx context.Context, backend *ipnlocal.LocalBackend) error { + ctx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + r, err := s.Call(ctx, "POST", "/localapi/v0/logout", nil) + defer r.Body().Close() + + if err != nil { + log.Printf("logout: %s", err) + logoutctx, logoutcancel := context.WithTimeout(ctx, 5*time.Minute) + defer logoutcancel() + backend.Logout(logoutctx) + } + + return err +} diff --git a/pkg/localapiservice/localapishim.go b/pkg/localapiservice/localapishim.go new file mode 100644 index 0000000..c16c081 --- /dev/null +++ b/pkg/localapiservice/localapishim.go @@ -0,0 +1,202 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package localapiservice + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "time" + "unsafe" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" +) + +// #include +import "C" + +// Shims the LocalApiClient class from the Kotlin side to the Go side's LocalAPIService. +var shim struct { + // localApiClient is a global reference to the com.tailscale.ipn.ui.localapi.LocalApiClient class. + clientClass jnipkg.Class + + // notifierClass is a global reference to the com.tailscale.ipn.ui.notifier.Notifier class. + notifierClass jnipkg.Class + + // Typically a shared LocalAPIService instance. + service *LocalAPIService + + backend *ipnlocal.LocalBackend + + busWatchers map[string]func() + + jvm *jnipkg.JVM +} + +//export Java_com_tailscale_ipn_ui_localapi_LocalApiClient_doRequest +func Java_com_tailscale_ipn_ui_localapi_LocalApiClient_doRequest( + env *C.JNIEnv, + cls C.jclass, + jpath C.jstring, + jmethod C.jstring, + jbody C.jbyteArray, + jcookie C.jstring) { + + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + + // The API Path + pathRef := jnipkg.NewGlobalRef(jenv, jnipkg.Object(jpath)) + pathStr := jnipkg.GoString(jenv, jnipkg.String(pathRef)) + defer jnipkg.DeleteGlobalRef(jenv, pathRef) + + // The HTTP verb + methodRef := jnipkg.NewGlobalRef(jenv, jnipkg.Object(jmethod)) + methodStr := jnipkg.GoString(jenv, jnipkg.String(methodRef)) + defer jnipkg.DeleteGlobalRef(jenv, methodRef) + + // The body string. This is optional and may be empty. + bodyRef := jnipkg.NewGlobalRef(jenv, jnipkg.Object(jbody)) + bodyArray := jnipkg.GetByteArrayElements(jenv, jnipkg.ByteArray(bodyRef)) + defer jnipkg.DeleteGlobalRef(jenv, bodyRef) + + resp := doLocalAPIRequest(pathStr, methodStr, bodyArray) + + jrespBody := jnipkg.NewByteArray(jenv, resp) + respBody := jnipkg.Value(jrespBody) + cookie := jnipkg.Value(jcookie) + onResponse := jnipkg.GetMethodID(jenv, shim.clientClass, "onResponse", "([BLjava/lang/String;)V") + + jnipkg.CallVoidMethod(jenv, jnipkg.Object(cls), onResponse, respBody, cookie) +} + +func doLocalAPIRequest(path string, method string, body []byte) []byte { + if shim.service == nil { + return []byte("{\"error\":\"Not Ready\"}") + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + var reader io.Reader = nil + if len(body) > 0 { + reader = bytes.NewReader(body) + } + + r, err := shim.service.Call(ctx, method, path, reader) + defer r.Body().Close() + + if err != nil { + return []byte("{\"error\":\"" + err.Error() + "\"}") + } + respBytes, err := io.ReadAll(r.Body()) + if err != nil { + return []byte("{\"error\":\"" + err.Error() + "\"}") + } + return respBytes +} + +// Assign a localAPIService to our shim for handling incoming localapi requests from the Kotlin side. +func ConfigureShim(jvm *jnipkg.JVM, appCtx jnipkg.Object, s *LocalAPIService, b *ipnlocal.LocalBackend) { + shim.busWatchers = make(map[string]func()) + shim.service = s + shim.backend = b + + configureLocalApiJNIHandler(jvm, appCtx) + + // Let the Kotlin side know we're ready to handle requests. + jnipkg.Do(jvm, func(env *jnipkg.Env) error { + onReadyAPI := jnipkg.GetStaticMethodID(env, shim.clientClass, "onReady", "()V") + jnipkg.CallStaticVoidMethod(env, shim.clientClass, onReadyAPI) + + onNotifyNot := jnipkg.GetStaticMethodID(env, shim.notifierClass, "onReady", "()V") + jnipkg.CallStaticVoidMethod(env, shim.notifierClass, onNotifyNot) + + log.Printf("LocalAPI Shim ready") + return nil + }) +} + +// Loads the Kotlin-side LocalApiClient class and stores it in a global reference. +func configureLocalApiJNIHandler(jvm *jnipkg.JVM, appCtx jnipkg.Object) error { + shim.jvm = jvm + + return jnipkg.Do(jvm, func(env *jnipkg.Env) error { + loader := jnipkg.ClassLoaderFor(env, appCtx) + cl, err := jnipkg.LoadClass(env, loader, "com.tailscale.ipn.ui.localapi.LocalApiClient") + if err != nil { + return err + } + shim.clientClass = jnipkg.Class(jnipkg.NewGlobalRef(env, jnipkg.Object(cl))) + + cl, err = jnipkg.LoadClass(env, loader, "com.tailscale.ipn.ui.notifier.Notifier") + if err != nil { + return err + } + shim.notifierClass = jnipkg.Class(jnipkg.NewGlobalRef(env, jnipkg.Object(cl))) + + return nil + }) +} + +//export Java_com_tailscale_ipn_ui_notifier_Notifier_stopIPNBusWatcher +func Java_com_tailscale_ipn_ui_notifier_Notifier_stopIPNBusWatcher( + env *C.JNIEnv, + cls C.jclass, + jsessionId C.jstring) { + + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + + sessionIdRef := jnipkg.NewGlobalRef(jenv, jnipkg.Object(jsessionId)) + sessionId := jnipkg.GoString(jenv, jnipkg.String(sessionIdRef)) + defer jnipkg.DeleteGlobalRef(jenv, sessionIdRef) + + cancel := shim.busWatchers[sessionId] + if cancel != nil { + log.Printf("Deregistering app layer bus watcher with sessionid: %s", sessionId) + cancel() + delete(shim.busWatchers, sessionId) + } else { + log.Printf("Error: Could not find bus watcher with sessionid: %s", sessionId) + } +} + +//export Java_com_tailscale_ipn_ui_notifier_Notifier_startIPNBusWatcher +func Java_com_tailscale_ipn_ui_notifier_Notifier_startIPNBusWatcher( + env *C.JNIEnv, + cls C.jclass, + jsessionId C.jstring, + jmask C.jint) { + + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + + sessionIdRef := jnipkg.NewGlobalRef(jenv, jnipkg.Object(jsessionId)) + sessionId := jnipkg.GoString(jenv, jnipkg.String(sessionIdRef)) + defer jnipkg.DeleteGlobalRef(jenv, sessionIdRef) + + log.Printf("Registering app layer bus watcher with sessionid: %s", sessionId) + + ctx, cancel := context.WithCancel(context.Background()) + shim.busWatchers[sessionId] = cancel + opts := ipn.NotifyWatchOpt(jmask) + + shim.backend.WatchNotifications(ctx, opts, func() { + // onWatchAdded + }, func(roNotify *ipn.Notify) bool { + js, err := json.Marshal(roNotify) + if err != nil { + return true + } + jnipkg.Do(shim.jvm, func(env *jnipkg.Env) error { + jjson := jnipkg.JavaString(env, string(js)) + onNotify := jnipkg.GetMethodID(env, shim.notifierClass, "onNotify", "(Ljava/lang/String;Ljava/lang/String;)V") + jnipkg.CallVoidMethod(env, jnipkg.Object(cls), onNotify, jnipkg.Value(jjson), jnipkg.Value(jsessionId)) + return nil + }) + return true + }) + +} diff --git a/pkg/localapiservice/response.go b/pkg/localapiservice/response.go new file mode 100644 index 0000000..9e30ebc --- /dev/null +++ b/pkg/localapiservice/response.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package localapiservice + +import ( + "net" + "net/http" + "sync" +) + +// Response represents the result of processing an localAPI request. +// On completion, the response body can be read out of the bodyWriter. +type Response struct { + headers http.Header + status int + bodyWriter net.Conn + bodyReader net.Conn + startWritingBody chan interface{} + startWritingBodyOnce sync.Once +} + +func (r *Response) Header() http.Header { + return r.headers +} + +// Write writes the data to the response body which an then be +// read out as a json object. +func (r *Response) Write(data []byte) (int, error) { + r.Flush() + if r.status == 0 { + r.WriteHeader(http.StatusOK) + } + return r.bodyWriter.Write(data) +} + +func (r *Response) WriteHeader(statusCode int) { + r.status = statusCode +} + +func (r *Response) Body() net.Conn { + return r.bodyReader +} + +func (r *Response) StatusCode() int { + return r.status +} + +func (r *Response) Flush() { + r.startWritingBodyOnce.Do(func() { + close(r.startWritingBody) + }) +} diff --git a/pkg/tailscale/app.go b/pkg/tailscale/app.go new file mode 100644 index 0000000..e77f520 --- /dev/null +++ b/pkg/tailscale/app.go @@ -0,0 +1,85 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "os" + "path/filepath" + "sync" + "sync/atomic" + "unsafe" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" + "github.com/tailscale/tailscale-android/pkg/localapiservice" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/types/logid" +) + +// #include +import "C" + +type App struct { + jvm *jnipkg.JVM + // appCtx is a global reference to the com.tailscale.ipn.App instance. + appCtx jnipkg.Object + + store *stateStore + logIDPublicAtomic atomic.Pointer[logid.PublicID] + + localAPI *localapiservice.LocalAPIService + backend *ipnlocal.LocalBackend +} + +var android struct { + // mu protects all fields of this structure. However, once a + // non-nil jvm is returned from javaVM, all the other fields may + // be accessed unlocked. + mu sync.Mutex + jvm *jnipkg.JVM + + // appCtx is the global Android App context. + appCtx C.jobject +} + +func initJVM(env *C.JNIEnv, ctx C.jobject) { + android.mu.Lock() + defer android.mu.Unlock() + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + res, err := jnipkg.GetJavaVM(jenv) + if err != nil { + panic("eror: GetJavaVM failed") + } + android.jvm = res + android.appCtx = C.jobject(jnipkg.NewGlobalRef(jenv, jnipkg.Object(ctx))) +} + +//export Java_com_tailscale_ipn_App_initBackend +func Java_com_tailscale_ipn_App_initBackend(env *C.JNIEnv, class C.jclass, jdataDir C.jbyteArray, context C.jobject) { + initJVM(env, context) + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + dirBytes := jnipkg.GetByteArrayElements(jenv, jnipkg.ByteArray(jdataDir)) + if dirBytes == nil { + panic("runGoMain: GetByteArrayElements failed") + } + n := jnipkg.GetArrayLength(jenv, jnipkg.ByteArray(jdataDir)) + dataDir := C.GoStringN((*C.char)(unsafe.Pointer(&dirBytes[0])), C.int(n)) + + // Set XDG_CACHE_HOME to make os.UserCacheDir work. + if _, exists := os.LookupEnv("XDG_CACHE_HOME"); !exists { + cachePath := filepath.Join(dataDir, "cache") + os.Setenv("XDG_CACHE_HOME", cachePath) + } + // Set XDG_CONFIG_HOME to make os.UserConfigDir work. + if _, exists := os.LookupEnv("XDG_CONFIG_HOME"); !exists { + cfgPath := filepath.Join(dataDir, "config") + os.Setenv("XDG_CONFIG_HOME", cfgPath) + } + // Set HOME to make os.UserHomeDir work. + if _, exists := os.LookupEnv("HOME"); !exists { + os.Setenv("HOME", dataDir) + } + + dataDirChan <- dataDir + main() +} diff --git a/pkg/tailscale/backend.go b/pkg/tailscale/backend.go new file mode 100644 index 0000000..51b5e20 --- /dev/null +++ b/pkg/tailscale/backend.go @@ -0,0 +1,285 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "log" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" + "github.com/tailscale/tailscale-android/pkg/localapiservice" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/localapi" + "tailscale.com/logtail" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/netns" + "tailscale.com/net/tsdial" + "tailscale.com/paths" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/netmap" + "tailscale.com/wgengine" + "tailscale.com/wgengine/netstack" + "tailscale.com/wgengine/router" +) + +import "C" + +type BackendState struct { + State ipn.State + NetworkMap *netmap.NetworkMap + LostInternet bool +} + +type backend struct { + engine wgengine.Engine + backend *ipnlocal.LocalBackend + sys *tsd.System + devices *multiTUN + settings settingsFunc + lastCfg *router.Config + lastDNSCfg *dns.OSConfig + netMon *netmon.Monitor + + logIDPublic logid.PublicID + logger *logtail.Logger + + // avoidEmptyDNS controls whether to use fallback nameservers + // when no nameservers are provided by Tailscale. + avoidEmptyDNS bool + + jvm *jnipkg.JVM + appCtx jnipkg.Object +} + +type settingsFunc func(*router.Config, *dns.OSConfig) error + +func (a *App) runBackend(ctx context.Context) error { + appDir, err := dataDir() + if err != nil { + fatalErr(err) + } + paths.AppSharedDir.Store(appDir) + hostinfo.SetOSVersion(a.osVersion()) + if !googleSignInEnabled() { + hostinfo.SetPackage("nogoogle") + } + deviceModel := a.modelName() + if a.isChromeOS() { + deviceModel = "ChromeOS: " + deviceModel + } + hostinfo.SetDeviceModel(deviceModel) + + type configPair struct { + rcfg *router.Config + dcfg *dns.OSConfig + } + configs := make(chan configPair) + configErrs := make(chan error) + b, err := newBackend(appDir, a.jvm, a.appCtx, a.store, func(rcfg *router.Config, dcfg *dns.OSConfig) error { + if rcfg == nil { + return nil + } + configs <- configPair{rcfg, dcfg} + return <-configErrs + }) + if err != nil { + return err + } + a.logIDPublicAtomic.Store(&b.logIDPublic) + a.backend = b.backend + defer b.CloseTUNs() + + h := localapi.NewHandler(b.backend, log.Printf, b.sys.NetMon.Get(), *a.logIDPublicAtomic.Load()) + h.PermitRead = true + h.PermitWrite = true + a.localAPI = localapiservice.New(h) + + // Share the localAPI with the JNI shim + //localapiservice.SetLocalAPIService(a.localAPI) + localapiservice.ConfigureShim(a.jvm, a.appCtx, a.localAPI, b.backend) + + // Contrary to the documentation for VpnService.Builder.addDnsServer, + // ChromeOS doesn't fall back to the underlying network nameservers if + // we don't provide any. + b.avoidEmptyDNS = a.isChromeOS() + + var ( + cfg configPair + state BackendState + service jnipkg.Object // of IPNService + ) + for { + select { + case c := <-configs: + cfg = c + if b == nil || service == 0 || cfg.rcfg == nil { + configErrs <- nil + break + } + configErrs <- b.updateTUN(service, cfg.rcfg, cfg.dcfg) + case s := <-onVPNRequested: + jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + if jnipkg.IsSameObject(env, s, service) { + // We already have a reference. + jnipkg.DeleteGlobalRef(env, s) + return nil + } + if service != 0 { + jnipkg.DeleteGlobalRef(env, service) + } + netns.SetAndroidProtectFunc(func(fd int) error { + return jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + // Call https://developer.android.com/reference/android/net/VpnService#protect(int) + // to mark fd as a socket that should bypass the VPN and use the underlying network. + cls := jnipkg.GetObjectClass(env, s) + m := jnipkg.GetMethodID(env, cls, "protect", "(I)Z") + ok, err := jnipkg.CallBooleanMethod(env, s, m, jnipkg.Value(fd)) + // TODO(bradfitz): return an error back up to netns if this fails, once + // we've had some experience with this and analyzed the logs over a wide + // range of Android phones. For now we're being paranoid and conservative + // and do the JNI call to protect best effort, only logging if it fails. + // The risk of returning an error is that it breaks users on some Android + // versions even when they're not using exit nodes. I'd rather the + // relatively few number of exit node users file bug reports if Tailscale + // doesn't work and then we can look for this log print. + if err != nil || !ok { + log.Printf("[unexpected] VpnService.protect(%d) = %v, %v", fd, ok, err) + } + return nil // even on error. see big TODO above. + }) + }) + log.Printf("onVPNRequested: rebind required") + // TODO(catzkorn): When we start the android application + // we bind sockets before we have access to the VpnService.protect() + // function which is needed to avoid routing loops. When we activate + // the service we get access to the protect, but do not retrospectively + // protect the sockets already opened, which breaks connectivity. + // As a temporary fix, we rebind and protect the magicsock.Conn on connect + // which restores connectivity. + // See https://github.com/tailscale/corp/issues/13814 + b.backend.DebugRebind() + + service = s + return nil + }) + if m := state.NetworkMap; m != nil { + // TODO + } + if cfg.rcfg != nil && state.State >= ipn.Starting { + if err := b.updateTUN(service, cfg.rcfg, cfg.dcfg); err != nil { + log.Printf("VPN update failed: %v", err) + notifyVPNClosed() + } + } + case s := <-onDisconnect: + b.CloseTUNs() + jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + defer jnipkg.DeleteGlobalRef(env, s) + if jnipkg.IsSameObject(env, service, s) { + netns.SetAndroidProtectFunc(nil) + jnipkg.DeleteGlobalRef(env, service) + service = 0 + } + return nil + }) + if state.State >= ipn.Starting { + notifyVPNClosed() + } + case <-onDNSConfigChanged: + if b != nil { + go b.NetworkChanged() + } + } + } +} + +func newBackend(dataDir string, jvm *jnipkg.JVM, appCtx jnipkg.Object, store *stateStore, + settings settingsFunc) (*backend, error) { + + sys := new(tsd.System) + sys.Set(store) + + logf := logger.RusagePrefixLog(log.Printf) + b := &backend{ + jvm: jvm, + devices: newTUNDevices(), + settings: settings, + appCtx: appCtx, + } + var logID logid.PrivateID + logID.UnmarshalText([]byte("dead0000dead0000dead0000dead0000dead0000dead0000dead0000dead0000")) + storedLogID, err := store.read(logPrefKey) + // In all failure cases we ignore any errors and continue with the dead value above. + if err != nil || storedLogID == nil { + // Read failed or there was no previous log id. + newLogID, err := logid.NewPrivateID() + if err == nil { + logID = newLogID + enc, err := newLogID.MarshalText() + if err == nil { + store.write(logPrefKey, enc) + } + } + } else { + logID.UnmarshalText([]byte(storedLogID)) + } + + netMon, err := netmon.New(logf) + if err != nil { + log.Printf("netmon.New: %w", err) + } + b.netMon = netMon + b.setupLogs(dataDir, logID, logf) + dialer := new(tsdial.Dialer) + cb := &router.CallbackRouter{ + SetBoth: b.setCfg, + SplitDNS: false, + GetBaseConfigFunc: b.getDNSBaseConfig, + } + engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ + Tun: b.devices, + Router: cb, + DNS: cb, + Dialer: dialer, + SetSubsystem: sys.Set, + NetMon: b.netMon, + }) + if err != nil { + return nil, fmt.Errorf("runBackend: NewUserspaceEngine: %v", err) + } + sys.Set(engine) + b.logIDPublic = logID.Public() + ns, err := netstack.Create(logf, sys.Tun.Get(), engine, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + if err != nil { + return nil, fmt.Errorf("netstack.Create: %w", err) + } + sys.Set(ns) + ns.ProcessLocalIPs = false // let Android kernel handle it; VpnBuilder sets this up + ns.ProcessSubnets = true // for Android-being-an-exit-node support + sys.NetstackRouter.Set(true) + if w, ok := sys.Tun.GetOK(); ok { + w.Start() + } + lb, err := ipnlocal.NewLocalBackend(logf, logID.Public(), sys, 0) + if err != nil { + engine.Close() + return nil, fmt.Errorf("runBackend: NewLocalBackend: %v", err) + } + if err := ns.Start(lb); err != nil { + return nil, fmt.Errorf("startNetstack: %w", err) + } + if b.logger != nil { + lb.SetLogFlusher(b.logger.StartFlush) + } + b.engine = engine + b.backend = lb + b.sys = sys + return b, nil +} diff --git a/pkg/tailscale/callbacks.go b/pkg/tailscale/callbacks.go new file mode 100644 index 0000000..25f5eb9 --- /dev/null +++ b/pkg/tailscale/callbacks.go @@ -0,0 +1,151 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "sync" + "unsafe" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" +) + +// #include +import "C" + +var ( + // onVPNPrepared is notified when VpnService.prepare succeeds. + onVPNPrepared = make(chan struct{}, 1) + // onVPNClosed is notified when VpnService.prepare fails, or when + // the a running VPN connection is closed. + onVPNClosed = make(chan struct{}, 1) + // onVPNRevoked is notified whenever the VPN service is revoked. + onVPNRevoked = make(chan struct{}, 1) + + // onVPNRequested receives global IPNService references when + // a VPN connection is requested. + onVPNRequested = make(chan jnipkg.Object) + // onDisconnect receives global IPNService references when + // disconnecting. + onDisconnect = make(chan jnipkg.Object) + + onConnect = make(chan ConnectEvent) + + // onGoogleToken receives google ID tokens. + onGoogleToken = make(chan string) + + // onDNSConfigChanged is notified when the network changes and the DNS config needs to be updated. + onDNSConfigChanged = make(chan struct{}, 1) +) + +const ( + // Request codes for Android callbacks. + // requestSignin is for Google Sign-In. + requestSignin C.jint = 1000 + iota + // requestPrepareVPN is for when Android's VpnService.prepare + // completes. + requestPrepareVPN +) + +// resultOK is Android's Activity.RESULT_OK. +const resultOK = -1 + +//export Java_com_tailscale_ipn_App_onVPNPrepared +func Java_com_tailscale_ipn_App_onVPNPrepared(env *C.JNIEnv, class C.jclass) { + notifyVPNPrepared() +} + +//export Java_com_tailscale_ipn_IPNService_requestVPN +func Java_com_tailscale_ipn_IPNService_requestVPN(env *C.JNIEnv, this C.jobject) { + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + onVPNRequested <- jnipkg.NewGlobalRef(jenv, jnipkg.Object(this)) +} + +//export Java_com_tailscale_ipn_IPNService_connect +func Java_com_tailscale_ipn_IPNService_connect(env *C.JNIEnv, this C.jobject) { + onConnect <- ConnectEvent{Enable: true} +} + +//export Java_com_tailscale_ipn_IPNService_disconnect +func Java_com_tailscale_ipn_IPNService_disconnect(env *C.JNIEnv, this C.jobject) { + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + onDisconnect <- jnipkg.NewGlobalRef(jenv, jnipkg.Object(this)) +} + +//export Java_com_tailscale_ipn_StartVPNWorker_connect +func Java_com_tailscale_ipn_StartVPNWorker_connect(env *C.JNIEnv, this C.jobject) { + onConnect <- ConnectEvent{Enable: true} +} + +//export Java_com_tailscale_ipn_StopVPNWorker_disconnect +func Java_com_tailscale_ipn_StopVPNWorker_disconnect(env *C.JNIEnv, this C.jobject) { + onConnect <- ConnectEvent{Enable: false} +} + +//export Java_com_tailscale_ipn_Peer_onActivityResult0 +func Java_com_tailscale_ipn_Peer_onActivityResult0(env *C.JNIEnv, cls C.jclass, act C.jobject, reqCode, resCode C.jint) { + switch reqCode { + case requestSignin: + if resCode != resultOK { + onGoogleToken <- "" + break + } + jenv := (*jnipkg.Env)(unsafe.Pointer(env)) + m := jnipkg.GetStaticMethodID(jenv, googleClass, + "getIdTokenForActivity", "(Landroid/app/Activity;)Ljava/lang/String;") + idToken, err := jnipkg.CallStaticObjectMethod(jenv, googleClass, m, jnipkg.Value(act)) + if err != nil { + fatalErr(err) + break + } + tok := jnipkg.GoString(jenv, jnipkg.String(idToken)) + onGoogleToken <- tok + case requestPrepareVPN: + if resCode == resultOK { + notifyVPNPrepared() + } else { + notifyVPNClosed() + notifyVPNRevoked() + } + } +} + +//export Java_com_tailscale_ipn_App_onDnsConfigChanged +func Java_com_tailscale_ipn_App_onDnsConfigChanged(env *C.JNIEnv, cls C.jclass) { + select { + case onDNSConfigChanged <- struct{}{}: + default: + } +} + +func notifyVPNPrepared() { + select { + case onVPNPrepared <- struct{}{}: + default: + } +} + +func notifyVPNRevoked() { + select { + case onVPNRevoked <- struct{}{}: + default: + } +} + +func notifyVPNClosed() { + select { + case onVPNClosed <- struct{}{}: + default: + } +} + +var android struct { + // mu protects all fields of this structure. However, once a + // non-nil jvm is returned from javaVM, all the other fields may + // be accessed unlocked. + mu sync.Mutex + jvm *jnipkg.JVM + + // appCtx is the global Android App context. + appCtx C.jobject +} diff --git a/pkg/tailscale/log.go b/pkg/tailscale/log.go new file mode 100644 index 0000000..407c1fc --- /dev/null +++ b/pkg/tailscale/log.go @@ -0,0 +1,91 @@ +// Gratefully borrowed from Gio https://gioui.org/ +// SPDX-License-Identifier: MIT + +package main + +/* +#cgo LDFLAGS: -llog + +#include +#include +*/ +import "C" + +import ( + "bufio" + "log" + "os" + "path/filepath" + "runtime" + "syscall" + "unsafe" +) + +// 1024 is the truncation limit from android/log.h, plus a \n. +const logLineLimit = 1024 + +var ID = filepath.Base(os.Args[0]) + +var logTag = C.CString(ID) + +func init() { + // Android's logcat already includes timestamps. + log.SetFlags(log.Flags() &^ log.LstdFlags) + log.SetOutput(new(androidLogWriter)) + + // Redirect stdout and stderr to the Android logger. + logFd(os.Stdout.Fd()) + logFd(os.Stderr.Fd()) +} + +type androidLogWriter struct { + // buf has room for the maximum log line, plus a terminating '\0'. + buf [logLineLimit + 1]byte +} + +func (w *androidLogWriter) Write(data []byte) (int, error) { + n := 0 + for len(data) > 0 { + msg := data + // Truncate the buffer, leaving space for the '\0'. + if max := len(w.buf) - 1; len(msg) > max { + msg = msg[:max] + } + buf := w.buf[:len(msg)+1] + copy(buf, msg) + // Terminating '\0'. + buf[len(msg)] = 0 + C.__android_log_write(C.ANDROID_LOG_INFO, logTag, (*C.char)(unsafe.Pointer(&buf[0]))) + n += len(msg) + data = data[len(msg):] + } + return n, nil +} + +func logFd(fd uintptr) { + r, w, err := os.Pipe() + if err != nil { + panic(err) + } + if err := syscall.Dup3(int(w.Fd()), int(fd), syscall.O_CLOEXEC); err != nil { + panic(err) + } + go func() { + lineBuf := bufio.NewReaderSize(r, logLineLimit) + // The buffer to pass to C, including the terminating '\0'. + buf := make([]byte, lineBuf.Size()+1) + cbuf := (*C.char)(unsafe.Pointer(&buf[0])) + for { + line, _, err := lineBuf.ReadLine() + if err != nil { + break + } + copy(buf, line) + buf[len(line)] = 0 + C.__android_log_write(C.ANDROID_LOG_INFO, logTag, cbuf) + } + // The garbage collector doesn't know that w's fd was dup'ed. + // Avoid finalizing w, and thereby avoid its finalizer closing its fd. + runtime.KeepAlive(w) + }() +} diff --git a/pkg/tailscale/multitun.go b/pkg/tailscale/multitun.go new file mode 100644 index 0000000..4531f88 --- /dev/null +++ b/pkg/tailscale/multitun.go @@ -0,0 +1,282 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +// multiTUN implements a tun.Device that supports multiple +// underlying devices. This is necessary because Android VPN devices +// have static configurations and wgengine.NewUserspaceEngine +// assumes a single static tun.Device. +type multiTUN struct { + // devices is for adding new devices. + devices chan tun.Device + // event is the combined event channel from all active devices. + events chan tun.Event + + close chan struct{} + closeErr chan error + + reads chan ioRequest + writes chan ioRequest + mtus chan chan mtuReply + names chan chan nameReply + shutdowns chan struct{} + shutdownDone chan struct{} +} + +// tunDevice wraps and drives a single run.Device. +type tunDevice struct { + dev tun.Device + // close closes the device. + close chan struct{} + closeDone chan error + // readDone is notified when the read goroutine is done. + readDone chan struct{} +} + +type ioRequest struct { + data [][]byte + sizes []int + offset int + reply chan<- ioReply +} + +type ioReply struct { + count int + err error +} + +type mtuReply struct { + mtu int + err error +} + +type nameReply struct { + name string + err error +} + +func newTUNDevices() *multiTUN { + d := &multiTUN{ + devices: make(chan tun.Device), + events: make(chan tun.Event), + close: make(chan struct{}), + closeErr: make(chan error), + reads: make(chan ioRequest), + writes: make(chan ioRequest), + mtus: make(chan chan mtuReply), + names: make(chan chan nameReply), + shutdowns: make(chan struct{}), + shutdownDone: make(chan struct{}), + } + go d.run() + return d +} + +func (d *multiTUN) run() { + var devices []*tunDevice + // readDone is the readDone channel of the device being read from. + var readDone chan struct{} + // runDone is the closeDone channel of the device being written to. + var runDone chan error + for { + select { + case <-readDone: + // The oldest device has reached EOF, replace it. + n := copy(devices, devices[1:]) + devices = devices[:n] + if len(devices) > 0 { + // Start reading from the next device. + dev := devices[0] + readDone = dev.readDone + go d.readFrom(dev) + } + case <-runDone: + // A device completed runDevice, replace it. + if len(devices) > 0 { + dev := devices[len(devices)-1] + runDone = dev.closeDone + go d.runDevice(dev) + } + case <-d.shutdowns: + // Shut down all devices. + for _, dev := range devices { + close(dev.close) + <-dev.closeDone + <-dev.readDone + } + devices = nil + d.shutdownDone <- struct{}{} + case <-d.close: + var derr error + for _, dev := range devices { + if err := <-dev.closeDone; err != nil { + derr = err + } + } + d.closeErr <- derr + return + case dev := <-d.devices: + if len(devices) > 0 { + // Ask the most recent device to stop. + prev := devices[len(devices)-1] + close(prev.close) + } + wrap := &tunDevice{ + dev: dev, + close: make(chan struct{}), + closeDone: make(chan error), + readDone: make(chan struct{}, 1), + } + if len(devices) == 0 { + // Start using this first device. + readDone = wrap.readDone + go d.readFrom(wrap) + runDone = wrap.closeDone + go d.runDevice(wrap) + } + devices = append(devices, wrap) + case m := <-d.mtus: + r := mtuReply{mtu: defaultMTU} + if len(devices) > 0 { + dev := devices[len(devices)-1] + r.mtu, r.err = dev.dev.MTU() + } + m <- r + case n := <-d.names: + var r nameReply + if len(devices) > 0 { + dev := devices[len(devices)-1] + r.name, r.err = dev.dev.Name() + } + n <- r + } + } +} + +func (d *multiTUN) readFrom(dev *tunDevice) { + defer func() { + dev.readDone <- struct{}{} + }() + for { + select { + case r := <-d.reads: + n, err := dev.dev.Read(r.data, r.sizes, r.offset) + stop := false + if err != nil { + select { + case <-dev.close: + stop = true + err = nil + default: + } + } + r.reply <- ioReply{n, err} + if stop { + return + } + case <-d.close: + return + } + } +} + +func (d *multiTUN) runDevice(dev *tunDevice) { + defer func() { + // The documentation for https://developer.android.com/reference/android/net/VpnService.Builder#establish() + // states that "Therefore, after draining the old file + // descriptor...", but pending Reads are never unblocked + // when a new descriptor is created. + // + // Close it instead and hope that no packets are lost. + dev.closeDone <- dev.dev.Close() + }() + // Pump device events. + go func() { + for { + select { + case e := <-dev.dev.Events(): + d.events <- e + case <-dev.close: + return + } + } + }() + for { + select { + case w := <-d.writes: + n, err := dev.dev.Write(w.data, w.offset) + w.reply <- ioReply{n, err} + case <-dev.close: + // Device closed. + return + case <-d.close: + // Multi-device closed. + return + } + } +} + +func (d *multiTUN) add(dev tun.Device) { + d.devices <- dev +} + +func (d *multiTUN) File() *os.File { + // The underlying file descriptor is not constant on Android. + // Let's hope no-one uses it. + panic("not available on Android") +} + +func (d *multiTUN) Read(data [][]byte, sizes []int, offset int) (int, error) { + r := make(chan ioReply) + d.reads <- ioRequest{data, sizes, offset, r} + rep := <-r + return rep.count, rep.err +} + +func (d *multiTUN) Write(data [][]byte, offset int) (int, error) { + r := make(chan ioReply) + d.writes <- ioRequest{data, nil, offset, r} + rep := <-r + return rep.count, rep.err +} + +func (d *multiTUN) MTU() (int, error) { + r := make(chan mtuReply) + d.mtus <- r + rep := <-r + return rep.mtu, rep.err +} + +func (d *multiTUN) Name() (string, error) { + r := make(chan nameReply) + d.names <- r + rep := <-r + return rep.name, rep.err +} + +func (d *multiTUN) Events() <-chan tun.Event { + return d.events +} + +func (d *multiTUN) Shutdown() { + d.shutdowns <- struct{}{} + <-d.shutdownDone +} + +func (d *multiTUN) Close() error { + close(d.close) + return <-d.closeErr +} + +func (d *multiTUN) BatchSize() int { + // TODO(raggi): currently Android disallows the necessary ioctls to enable + // batching. File a bug. + return 1 +} diff --git a/pkg/tailscale/net.go b/pkg/tailscale/net.go new file mode 100644 index 0000000..d2b14c9 --- /dev/null +++ b/pkg/tailscale/net.go @@ -0,0 +1,346 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "errors" + "fmt" + "log" + "net" + "net/netip" + "reflect" + "strings" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" + "inet.af/netaddr" + "tailscale.com/net/dns" + "tailscale.com/net/interfaces" + "tailscale.com/util/dnsname" + "tailscale.com/wgengine/router" +) + +import "C" + +// errVPNNotPrepared is used when VPNService.Builder.establish returns +// null, either because the VPNService is not yet prepared or because +// VPN status was revoked. +var errVPNNotPrepared = errors.New("VPN service not prepared or was revoked") + +// errMultipleUsers is used when we get a "INTERACT_ACROSS_USERS" error, which +// happens due to a bug in Android. See: +// +// https://github.com/tailscale/tailscale/issues/2180 +var errMultipleUsers = errors.New("VPN cannot be created on this device due to an Android bug with multiple users") + +// Report interfaces in the device in net.Interface format. +func (a *App) getInterfaces() ([]interfaces.Interface, error) { + var ifaceString string + err := jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, a.appCtx) + m := jnipkg.GetMethodID(env, cls, "getInterfacesAsString", "()Ljava/lang/String;") + n, err := jnipkg.CallObjectMethod(env, a.appCtx, m) + ifaceString = jnipkg.GoString(env, jnipkg.String(n)) + return err + + }) + var ifaces []interfaces.Interface + if err != nil { + return ifaces, err + } + + for _, iface := range strings.Split(ifaceString, "\n") { + // Example of the strings we're processing: + // wlan0 30 1500 true true false false true | fe80::2f60:2c82:4163:8389%wlan0/64 10.1.10.131/24 + // r_rmnet_data0 21 1500 true false false false false | fe80::9318:6093:d1ad:ba7f%r_rmnet_data0/64 + // mnet_data2 12 1500 true false false false false | fe80::3c8c:44dc:46a9:9907%rmnet_data2/64 + + if strings.TrimSpace(iface) == "" { + continue + } + + fields := strings.Split(iface, "|") + if len(fields) != 2 { + log.Printf("getInterfaces: unable to split %q", iface) + continue + } + + var name string + var index, mtu int + var up, broadcast, loopback, pointToPoint, multicast bool + _, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t", + &name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast) + if err != nil { + log.Printf("getInterfaces: unable to parse %q: %v", iface, err) + continue + } + + newIf := interfaces.Interface{ + Interface: &net.Interface{ + Name: name, + Index: index, + MTU: mtu, + }, + AltAddrs: []net.Addr{}, // non-nil to avoid Go using netlink + } + if up { + newIf.Flags |= net.FlagUp + } + if broadcast { + newIf.Flags |= net.FlagBroadcast + } + if loopback { + newIf.Flags |= net.FlagLoopback + } + if pointToPoint { + newIf.Flags |= net.FlagPointToPoint + } + if multicast { + newIf.Flags |= net.FlagMulticast + } + + addrs := strings.Trim(fields[1], " \n") + for _, addr := range strings.Split(addrs, " ") { + ip, err := netaddr.ParseIPPrefix(addr) + if err == nil { + newIf.AltAddrs = append(newIf.AltAddrs, ip.IPNet()) + } + } + + ifaces = append(ifaces, newIf) + } + + return ifaces, nil +} + +// googleDNSServers are used on ChromeOS, where an empty VpnBuilder DNS setting results +// in erasing the platform DNS servers. The developer docs say this is not supposed to happen, +// but nonetheless it does. +var googleDNSServers = []netip.Addr{ + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("8.8.4.4"), + netip.MustParseAddr("2001:4860:4860::8888"), + netip.MustParseAddr("2001:4860:4860::8844"), +} + +func (b *backend) updateTUN(service jnipkg.Object, rcfg *router.Config, dcfg *dns.OSConfig) error { + if reflect.DeepEqual(rcfg, b.lastCfg) && reflect.DeepEqual(dcfg, b.lastDNSCfg) { + return nil + } + + // Close previous tunnel(s). + // This is necessary for ChromeOS, native Android devices + // seem to handle seamless handover between tunnels correctly. + // + // TODO(eliasnaur): If seamless handover becomes a desirable feature, skip + // the closing on ChromeOS. + b.CloseTUNs() + + if len(rcfg.LocalAddrs) == 0 { + return nil + } + err := jnipkg.Do(b.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, service) + // Construct a VPNService.Builder. IPNService.newBuilder calls + // setConfigureIntent, and allowFamily for both IPv4 and IPv6. + m := jnipkg.GetMethodID(env, cls, "newBuilder", "()Landroid/net/VpnService$Builder;") + builder, err := jnipkg.CallObjectMethod(env, service, m) + if err != nil { + return fmt.Errorf("IPNService.newBuilder: %v", err) + } + bcls := jnipkg.GetObjectClass(env, builder) + + // builder.setMtu. + setMtu := jnipkg.GetMethodID(env, bcls, "setMtu", "(I)Landroid/net/VpnService$Builder;") + const mtu = defaultMTU + if _, err := jnipkg.CallObjectMethod(env, builder, setMtu, jnipkg.Value(mtu)); err != nil { + return fmt.Errorf("VpnService.Builder.setMtu: %v", err) + } + + // builder.addDnsServer + addDnsServer := jnipkg.GetMethodID(env, bcls, "addDnsServer", "(Ljava/lang/String;)Landroid/net/VpnService$Builder;") + // builder.addSearchDomain. + addSearchDomain := jnipkg.GetMethodID(env, bcls, "addSearchDomain", "(Ljava/lang/String;)Landroid/net/VpnService$Builder;") + if dcfg != nil { + nameservers := dcfg.Nameservers + if b.avoidEmptyDNS && len(nameservers) == 0 { + nameservers = googleDNSServers + } + for _, dns := range nameservers { + _, err = jnipkg.CallObjectMethod(env, + builder, + addDnsServer, + jnipkg.Value(jnipkg.JavaString(env, dns.String())), + ) + if err != nil { + return fmt.Errorf("VpnService.Builder.addDnsServer(%v): %v", dns, err) + } + } + + for _, dom := range dcfg.SearchDomains { + _, err = jnipkg.CallObjectMethod(env, + builder, + addSearchDomain, + jnipkg.Value(jnipkg.JavaString(env, dom.WithoutTrailingDot())), + ) + if err != nil { + return fmt.Errorf("VpnService.Builder.addSearchDomain(%v): %v", dom, err) + } + } + } + + // builder.addRoute. + addRoute := jnipkg.GetMethodID(env, bcls, "addRoute", "(Ljava/lang/String;I)Landroid/net/VpnService$Builder;") + for _, route := range rcfg.Routes { + // Normalize route address; Builder.addRoute does not accept non-zero masked bits. + route = route.Masked() + _, err = jnipkg.CallObjectMethod(env, + builder, + addRoute, + jnipkg.Value(jnipkg.JavaString(env, route.Addr().String())), + jnipkg.Value(route.Bits()), + ) + if err != nil { + return fmt.Errorf("VpnService.Builder.addRoute(%v): %v", route, err) + } + } + + // builder.addAddress. + addAddress := jnipkg.GetMethodID(env, bcls, "addAddress", "(Ljava/lang/String;I)Landroid/net/VpnService$Builder;") + for _, addr := range rcfg.LocalAddrs { + _, err = jnipkg.CallObjectMethod(env, + builder, + addAddress, + jnipkg.Value(jnipkg.JavaString(env, addr.Addr().String())), + jnipkg.Value(addr.Bits()), + ) + if err != nil { + return fmt.Errorf("VpnService.Builder.addAddress(%v): %v", addr, err) + } + } + + // builder.establish. + establish := jnipkg.GetMethodID(env, bcls, "establish", "()Landroid/os/ParcelFileDescriptor;") + parcelFD, err := jnipkg.CallObjectMethod(env, builder, establish) + if err != nil { + if strings.Contains(err.Error(), "INTERACT_ACROSS_USERS") { + return errMultipleUsers + } + return fmt.Errorf("VpnService.Builder.establish: %v", err) + } + if parcelFD == 0 { + return errVPNNotPrepared + } + + // detachFd. + parcelCls := jnipkg.GetObjectClass(env, parcelFD) + detachFd := jnipkg.GetMethodID(env, parcelCls, "detachFd", "()I") + tunFD, err := jnipkg.CallIntMethod(env, parcelFD, detachFd) + if err != nil { + return fmt.Errorf("detachFd: %v", err) + } + + // Create TUN device. + tunDev, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFD)) + if err != nil { + unix.Close(int(tunFD)) + return err + } + + b.devices.add(tunDev) + + return nil + }) + if err != nil { + b.lastCfg = nil + b.CloseTUNs() + return err + } + b.lastCfg = rcfg + b.lastDNSCfg = dcfg + return nil +} + +// CloseVPN closes any active TUN devices. +func (b *backend) CloseTUNs() { + b.lastCfg = nil + b.devices.Shutdown() +} + +func (b *backend) NetworkChanged() { + if b.sys != nil { + if nm, ok := b.sys.NetMon.GetOK(); ok { + nm.InjectEvent() + } + } +} + +func (b *backend) getDNSBaseConfig() (ret dns.OSConfig, _ error) { + defer func() { + // If we couldn't find any base nameservers, ultimately fall back to + // Google's. Normally Tailscale doesn't ever pick a default nameserver + // for users but in this case Android's APIs for reading the underlying + // DNS config are lacking, and almost all Android phones use Google + // services anyway, so it's a reasonable default: it's an ecosystem the + // user has selected by having an Android device. + if len(ret.Nameservers) == 0 && googleSignInEnabled() { + log.Printf("getDNSBaseConfig: none found; falling back to Google public DNS") + ret.Nameservers = append(ret.Nameservers, googleDNSServers...) + } + }() + baseConfig := b.getPlatformDNSConfig() + lines := strings.Split(baseConfig, "\n") + if len(lines) == 0 { + return dns.OSConfig{}, nil + } + + config := dns.OSConfig{} + addrs := strings.Trim(lines[0], " \n") + for _, addr := range strings.Split(addrs, " ") { + ip, err := netip.ParseAddr(addr) + if err == nil { + config.Nameservers = append(config.Nameservers, ip) + } + } + + if len(lines) > 1 { + for _, s := range strings.Split(strings.Trim(lines[1], " \n"), " ") { + domain, err := dnsname.ToFQDN(s) + if err != nil { + log.Printf("getDNSBaseConfig: unable to parse %q: %v", s, err) + continue + } + config.SearchDomains = append(config.SearchDomains, domain) + } + } + + return config, nil +} + +func (b *backend) getPlatformDNSConfig() string { + var baseConfig string + err := jnipkg.Do(b.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, b.appCtx) + m := jnipkg.GetMethodID(env, cls, "getDnsConfigObj", "()Lcom/tailscale/ipn/DnsConfig;") + dns, err := jnipkg.CallObjectMethod(env, b.appCtx, m) + if err != nil { + return fmt.Errorf("getDnsConfigObj: %v", err) + } + dnsCls := jnipkg.GetObjectClass(env, dns) + m = jnipkg.GetMethodID(env, dnsCls, "getDnsConfigAsString", "()Ljava/lang/String;") + n, err := jnipkg.CallObjectMethod(env, dns, m) + baseConfig = jnipkg.GoString(env, jnipkg.String(n)) + return err + }) + if err != nil { + log.Printf("getPlatformDNSConfig JNI: %v", err) + return "" + } + return baseConfig +} + +func (b *backend) setCfg(rcfg *router.Config, dcfg *dns.OSConfig) error { + return b.settings(rcfg, dcfg) +} diff --git a/pkg/tailscale/store.go b/pkg/tailscale/store.go new file mode 100644 index 0000000..1bc6b43 --- /dev/null +++ b/pkg/tailscale/store.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/base64" + + "tailscale.com/ipn" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" +) + +// stateStore is the Go interface for a persistent storage +// backend by androidx.security.crypto.EncryptedSharedPreferences (see +// App.java). +type stateStore struct { + jvm *jnipkg.JVM + // appCtx is the global Android app context. + appCtx jnipkg.Object + + // Cached method ids on appCtx. + encrypt jnipkg.MethodID + decrypt jnipkg.MethodID +} + +func newStateStore(jvm *jnipkg.JVM, appCtx jnipkg.Object) *stateStore { + s := &stateStore{ + jvm: jvm, + appCtx: appCtx, + } + jnipkg.Do(jvm, func(env *jnipkg.Env) error { + appCls := jnipkg.GetObjectClass(env, appCtx) + s.encrypt = jnipkg.GetMethodID( + env, appCls, + "encryptToPref", "(Ljava/lang/String;Ljava/lang/String;)V", + ) + s.decrypt = jnipkg.GetMethodID( + env, appCls, + "decryptFromPref", "(Ljava/lang/String;)Ljava/lang/String;", + ) + return nil + }) + return s +} + +func prefKeyFor(id ipn.StateKey) string { + return "statestore-" + string(id) +} + +func (s *stateStore) ReadString(key string, def string) (string, error) { + data, err := s.read(key) + if err != nil { + return def, err + } + if data == nil { + return def, nil + } + return string(data), nil +} + +func (s *stateStore) WriteString(key string, val string) error { + return s.write(key, []byte(val)) +} + +func (s *stateStore) ReadBool(key string, def bool) (bool, error) { + data, err := s.read(key) + if err != nil { + return def, err + } + if data == nil { + return def, nil + } + return string(data) == "true", nil +} + +func (s *stateStore) WriteBool(key string, val bool) error { + data := []byte("false") + if val { + data = []byte("true") + } + return s.write(key, data) +} + +func (s *stateStore) ReadState(id ipn.StateKey) ([]byte, error) { + state, err := s.read(prefKeyFor(id)) + if err != nil { + return nil, err + } + if state == nil { + return nil, ipn.ErrStateNotExist + } + return state, nil +} + +func (s *stateStore) WriteState(id ipn.StateKey, bs []byte) error { + prefKey := prefKeyFor(id) + return s.write(prefKey, bs) +} + +func (s *stateStore) read(key string) ([]byte, error) { + var data []byte + err := jnipkg.Do(s.jvm, func(env *jnipkg.Env) error { + jfile := jnipkg.JavaString(env, key) + plain, err := jnipkg.CallObjectMethod(env, s.appCtx, s.decrypt, + jnipkg.Value(jfile)) + if err != nil { + return err + } + b64 := jnipkg.GoString(env, jnipkg.String(plain)) + if b64 == "" { + return nil + } + data, err = base64.RawStdEncoding.DecodeString(b64) + return err + }) + return data, err +} + +func (s *stateStore) write(key string, value []byte) error { + bs64 := base64.RawStdEncoding.EncodeToString(value) + err := jnipkg.Do(s.jvm, func(env *jnipkg.Env) error { + jfile := jnipkg.JavaString(env, key) + jplain := jnipkg.JavaString(env, bs64) + err := jnipkg.CallVoidMethod(env, s.appCtx, s.encrypt, + jnipkg.Value(jfile), jnipkg.Value(jplain)) + if err != nil { + return err + } + return nil + }) + return err +} diff --git a/pkg/tailscale/tailscale.go b/pkg/tailscale/tailscale.go new file mode 100644 index 0000000..17d8fe4 --- /dev/null +++ b/pkg/tailscale/tailscale.go @@ -0,0 +1,183 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "log" + "net/http" + "path/filepath" + "time" + "unsafe" + + jnipkg "github.com/tailscale/tailscale-android/pkg/jni" + "tailscale.com/logpolicy" + "tailscale.com/logtail" + "tailscale.com/logtail/filch" + "tailscale.com/net/interfaces" + "tailscale.com/smallzstd" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/util/clientmetric" + "tailscale.com/util/must" +) + +import "C" + +var ( + // googleClass is a global reference to the com.tailscale.ipn.Google class. + googleClass jnipkg.Class +) + +const defaultMTU = 1280 // minimalMTU from wgengine/userspace.go + +const ( + logPrefKey = "privatelogid" + loginMethodPrefKey = "loginmethod" + customLoginServerPrefKey = "customloginserver" +) + +type ConnectEvent struct { + Enable bool +} + +func main() { + a := &App{ + jvm: (*jnipkg.JVM)(unsafe.Pointer(javaVM())), + appCtx: jnipkg.Object(appContext()), + } + + err := a.loadJNIGlobalClassRefs() + if err != nil { + fatalErr(err) + } + + a.store = newStateStore(a.jvm, a.appCtx) + interfaces.RegisterInterfaceGetter(a.getInterfaces) + go func() { + ctx := context.Background() + if err := a.runBackend(ctx); err != nil { + fatalErr(err) + } + }() +} + +func fatalErr(err error) { + // TODO: expose in UI. + log.Printf("fatal error: %v", err) +} + +// osVersion returns android.os.Build.VERSION.RELEASE. " [nogoogle]" is appended +// if Google Play services are not compiled in. +func (a *App) osVersion() string { + var version string + err := jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, a.appCtx) + m := jnipkg.GetMethodID(env, cls, "getOSVersion", "()Ljava/lang/String;") + n, err := jnipkg.CallObjectMethod(env, a.appCtx, m) + version = jnipkg.GoString(env, jnipkg.String(n)) + return err + }) + if err != nil { + panic(err) + } + return version +} + +// modelName return the MANUFACTURER + MODEL from +// android.os.Build. +func (a *App) modelName() string { + var model string + err := jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, a.appCtx) + m := jnipkg.GetMethodID(env, cls, "getModelName", "()Ljava/lang/String;") + n, err := jnipkg.CallObjectMethod(env, a.appCtx, m) + model = jnipkg.GoString(env, jnipkg.String(n)) + return err + }) + if err != nil { + panic(err) + } + return model +} + +func (a *App) isChromeOS() bool { + var chromeOS bool + err := jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + cls := jnipkg.GetObjectClass(env, a.appCtx) + m := jnipkg.GetMethodID(env, cls, "isChromeOS", "()Z") + b, err := jnipkg.CallBooleanMethod(env, a.appCtx, m) + chromeOS = b + return err + }) + if err != nil { + panic(err) + } + return chromeOS +} + +func googleSignInEnabled() bool { + return googleClass != 0 +} + +// Loads the global JNI class references. Failures here are fatal if the +// class ref is required for the app to function. +func (a *App) loadJNIGlobalClassRefs() error { + return jnipkg.Do(a.jvm, func(env *jnipkg.Env) error { + loader := jnipkg.ClassLoaderFor(env, a.appCtx) + cl, err := jnipkg.LoadClass(env, loader, "com.tailscale.ipn.Google") + if err != nil { + // Ignore load errors; the Google class is not included in F-Droid builds. + return nil + } + googleClass = jnipkg.Class(jnipkg.NewGlobalRef(env, jnipkg.Object(cl))) + return nil + }) +} + +// SetupLogs sets up remote logging. +func (b *backend) setupLogs(logDir string, logID logid.PrivateID, logf logger.Logf) { + if b.netMon == nil { + panic("netMon must be created prior to SetupLogs") + } + transport := logpolicy.NewLogtailTransport(logtail.DefaultHost, b.netMon, log.Printf) + + logcfg := logtail.Config{ + Collection: logtail.CollectionNode, + PrivateID: logID, + Stderr: log.Writer(), + MetricsDelta: clientmetric.EncodeLogTailMetricsDelta, + IncludeProcID: true, + IncludeProcSequence: true, + NewZstdEncoder: func() logtail.Encoder { + return must.Get(smallzstd.NewEncoder(nil)) + }, + HTTPC: &http.Client{Transport: transport}, + } + logcfg.FlushDelayFn = func() time.Duration { return 2 * time.Minute } + + filchOpts := filch.Options{ + ReplaceStderr: true, + } + + var filchErr error + if logDir != "" { + logPath := filepath.Join(logDir, "ipn.log.") + logcfg.Buffer, filchErr = filch.New(logPath, filchOpts) + } + + b.logger = logtail.NewLogger(logcfg, logf) + + log.SetFlags(0) + log.SetOutput(b.logger) + + log.Printf("goSetupLogs: success") + + if logDir == "" { + log.Printf("SetupLogs: no logDir, storing logs in memory") + } + if filchErr != nil { + log.Printf("SetupLogs: filch setup failed: %v", filchErr) + } +}