jni,cmd/tailscale: replace jni.EnvFor with explicit conversion

The EnvFor converted an uintptr to a pointer value, which is not
guaranteed to work in general. This change removes EnvFor and pushes the
potentially unsafe conversion to users of the jni package.

Fixes tailscale/tailscale#1195

Signed-off-by: Elias Naur <mail@eliasnaur.com>
pull/6/head
Elias Naur 3 years ago
parent 61d9733b24
commit ba38a9bb59

@ -157,7 +157,7 @@ func (b *backend) updateTUN(service jni.Object, cfg *router.Config) error {
if len(cfg.LocalAddrs) == 0 {
return nil
}
err := jni.Do(b.jvm, func(env jni.Env) error {
err := jni.Do(b.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, service)
// Construct a VPNService.Builder. IPNService.newBuilder calls
// setConfigureIntent, and allowFamily for both IPv4 and IPv6.

@ -87,13 +87,13 @@ func notifyVPNClosed() {
//export Java_com_tailscale_ipn_IPNService_connect
func Java_com_tailscale_ipn_IPNService_connect(env *C.JNIEnv, this C.jobject) {
jenv := jni.EnvFor(uintptr(unsafe.Pointer(env)))
jenv := (*jni.Env)(unsafe.Pointer(env))
onConnect <- jni.NewGlobalRef(jenv, jni.Object(this))
}
//export Java_com_tailscale_ipn_IPNService_disconnect
func Java_com_tailscale_ipn_IPNService_disconnect(env *C.JNIEnv, this C.jobject) {
jenv := jni.EnvFor(uintptr(unsafe.Pointer(env)))
jenv := (*jni.Env)(unsafe.Pointer(env))
onDisconnect <- jni.NewGlobalRef(jenv, jni.Object(this))
}
@ -119,7 +119,7 @@ func Java_com_tailscale_ipn_Peer_onActivityResult0(env *C.JNIEnv, cls C.jclass,
onGoogleToken <- ""
break
}
jenv := jni.EnvFor(uintptr(unsafe.Pointer(env)))
jenv := (*jni.Env)(unsafe.Pointer(env))
m := jni.GetStaticMethodID(jenv, googleClass,
"getIdTokenForActivity", "(Landroid/app/Activity;)Ljava/lang/String;")
idToken, err := jni.CallStaticObjectMethod(jenv, googleClass, m, jni.Value(act))

@ -121,7 +121,7 @@ func main() {
appCtx: jni.Object(app.AppContext()),
updates: make(chan struct{}, 1),
}
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
loader := jni.ClassLoaderFor(env, a.appCtx)
cl, err := jni.LoadClass(env, loader, "com.tailscale.ipn.Google")
if err != nil {
@ -282,7 +282,7 @@ func (a *App) runBackend() error {
go b.backend.SetPrefs(prefs)
}
case s := <-onConnect:
jni.Do(a.jvm, func(env jni.Env) error {
jni.Do(a.jvm, func(env *jni.Env) error {
if jni.IsSameObject(env, s, service) {
// We already have a reference.
jni.DeleteGlobalRef(env, s)
@ -312,7 +312,7 @@ func (a *App) runBackend() error {
a.notify(state)
case s := <-onDisconnect:
b.CloseTUNs()
jni.Do(a.jvm, func(env jni.Env) error {
jni.Do(a.jvm, func(env *jni.Env) error {
defer jni.DeleteGlobalRef(env, s)
if jni.IsSameObject(env, service, s) {
jni.DeleteGlobalRef(env, service)
@ -329,7 +329,7 @@ func (a *App) runBackend() error {
func (a *App) isChromeOS() bool {
var chromeOS bool
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, a.appCtx)
m := jni.GetMethodID(env, cls, "isChromeOS", "()Z")
b, err := jni.CallBooleanMethod(env, a.appCtx, m)
@ -346,7 +346,7 @@ func (a *App) isChromeOS() bool {
// useless os.Hostname().
func (a *App) hostname() string {
var hostname string
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, a.appCtx)
getHostname := jni.GetMethodID(env, cls, "getHostname", "()Ljava/lang/String;")
n, err := jni.CallObjectMethod(env, a.appCtx, getHostname)
@ -363,7 +363,7 @@ func (a *App) hostname() string {
// if Google Play services are not compiled in.
func (a *App) osVersion() string {
var version string
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, a.appCtx)
m := jni.GetMethodID(env, cls, "getOSVersion", "()Ljava/lang/String;")
n, err := jni.CallObjectMethod(env, a.appCtx, m)
@ -383,7 +383,7 @@ func (a *App) osVersion() string {
// android.os.Build.
func (a *App) modelName() string {
var model string
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, a.appCtx)
m := jni.GetMethodID(env, cls, "getModelName", "()Ljava/lang/String;")
n, err := jni.CallObjectMethod(env, a.appCtx, m)
@ -411,7 +411,7 @@ func (a *App) updateNotification(service jni.Object, state ipn.State) error {
default:
return nil
}
return jni.Do(a.jvm, func(env jni.Env) error {
return jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, service)
update := jni.GetMethodID(env, cls, "updateStatusNotification", "(Ljava/lang/String;Ljava/lang/String;)V")
jtitle := jni.JavaString(env, title)
@ -447,7 +447,7 @@ func (a *App) notifyExpiry(service jni.Object, expiry time.Time) *time.Timer {
default:
return time.NewTimer(d - aday)
}
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, service)
notify := jni.GetMethodID(env, cls, "notify", "(Ljava/lang/String;Ljava/lang/String;)V")
jtitle := jni.JavaString(env, title)
@ -514,7 +514,7 @@ func (a *App) runUI() error {
if activity == 0 {
return
}
jni.Do(a.jvm, func(env jni.Env) error {
jni.Do(a.jvm, func(env *jni.Env) error {
jni.DeleteGlobalRef(env, activity)
return nil
})
@ -613,7 +613,7 @@ func (a *App) runUI() error {
// signature.
func (a *App) isReleaseSigned() bool {
var cert []byte
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, a.appCtx)
m := jni.GetMethodID(env, cls, "getPackageCertificate", "()[B")
str, err := jni.CallObjectMethod(env, a.appCtx, m)
@ -756,7 +756,7 @@ func (a *App) signOut() {
if googleClass == 0 {
return
}
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
m := jni.GetStaticMethodID(env, googleClass,
"googleSignOut", "(Landroid/content/Context;)V")
return jni.CallStaticVoidMethod(env, googleClass, m, jni.Value(a.appCtx))
@ -770,7 +770,7 @@ func (a *App) googleSignIn(act jni.Object) {
if act == 0 || googleClass == 0 {
return
}
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
sid := jni.JavaString(env, serverOAuthID)
m := jni.GetStaticMethodID(env, googleClass,
"googleSignIn", "(Landroid/app/Activity;Ljava/lang/String;I)V")
@ -786,7 +786,7 @@ func (a *App) browseToURL(act jni.Object, url string) {
if act == 0 {
return
}
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
jurl := jni.JavaString(env, url)
return a.callVoidMethod(a.appCtx, "showURL", "(Landroid/app/Activity;Ljava/lang/String;)V", jni.Value(act), jni.Value(jurl))
})
@ -799,7 +799,7 @@ func (a *App) callVoidMethod(obj jni.Object, name, sig string, args ...jni.Value
if obj == 0 {
panic("invalid object")
}
return jni.Do(a.jvm, func(env jni.Env) error {
return jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, obj)
m := jni.GetMethodID(env, cls, name, sig)
return jni.CallVoidMethod(env, obj, m, args...)
@ -813,7 +813,7 @@ func (a *App) contextForView(view jni.Object) jni.Object {
panic("invalid object")
}
var ctx jni.Object
err := jni.Do(a.jvm, func(env jni.Env) error {
err := jni.Do(a.jvm, func(env *jni.Env) error {
cls := jni.GetObjectClass(env, view)
m := jni.GetMethodID(env, cls, "getContext", "()Landroid/content/Context;")
var err error

@ -30,7 +30,7 @@ func newStateStore(jvm *jni.JVM, appCtx jni.Object) *stateStore {
jvm: jvm,
appCtx: appCtx,
}
jni.Do(jvm, func(env jni.Env) error {
jni.Do(jvm, func(env *jni.Env) error {
appCls := jni.GetObjectClass(env, appCtx)
s.encrypt = jni.GetMethodID(
env, appCls,
@ -101,7 +101,7 @@ func (s *stateStore) WriteState(id ipn.StateKey, bs []byte) error {
func (s *stateStore) read(key string) ([]byte, error) {
var data []byte
err := jni.Do(s.jvm, func(env jni.Env) error {
err := jni.Do(s.jvm, func(env *jni.Env) error {
jfile := jni.JavaString(env, key)
plain, err := jni.CallObjectMethod(env, s.appCtx, s.decrypt,
jni.Value(jfile))
@ -120,7 +120,7 @@ func (s *stateStore) read(key string) ([]byte, error) {
func (s *stateStore) write(key string, value []byte) error {
bs64 := base64.RawStdEncoding.EncodeToString(value)
err := jni.Do(s.jvm, func(env jni.Env) error {
err := jni.Do(s.jvm, func(env *jni.Env) error {
jfile := jni.JavaString(env, key)
jplain := jni.JavaString(env, bs64)
err := jni.CallVoidMethod(env, s.appCtx, s.encrypt,

@ -52,9 +52,7 @@ import "C"
type JVM C.JavaVM
type Env struct {
env *C.JNIEnv
}
type Env C.JNIEnv
type (
Class C.jclass
@ -71,10 +69,8 @@ const (
False Boolean = C.JNI_FALSE
)
func EnvFor(envPtr uintptr) Env {
return Env{
env: (*C.JNIEnv)(unsafe.Pointer(envPtr)),
}
func env(e *Env) *C.JNIEnv {
return (*C.JNIEnv)(unsafe.Pointer(e))
}
func javavm(vm *JVM) *C.JavaVM {
@ -83,7 +79,7 @@ func javavm(vm *JVM) *C.JavaVM {
// Do invokes a function with a temporary JVM environment. The
// environment is not valid after the function returns.
func Do(vm *JVM, f func(env Env) error) error {
func Do(vm *JVM, f func(env *Env) error) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
var env *C.JNIEnv
@ -97,7 +93,7 @@ func Do(vm *JVM, f func(env Env) error) error {
defer C._jni_DetachCurrentThread(javavm(vm))
}
return f(Env{env})
return f((*Env)(unsafe.Pointer(env)))
}
func varArgs(args []Value) *C.jvalue {
@ -107,54 +103,54 @@ func varArgs(args []Value) *C.jvalue {
return (*C.jvalue)(unsafe.Pointer(&args[0]))
}
func IsSameObject(e Env, ref1, ref2 Object) bool {
same := C._jni_IsSameObject(e.env, C.jobject(ref1), C.jobject(ref2))
func IsSameObject(e *Env, ref1, ref2 Object) bool {
same := C._jni_IsSameObject(env(e), C.jobject(ref1), C.jobject(ref2))
return same == C.JNI_TRUE
}
func CallStaticIntMethod(e Env, cls Class, method MethodID, args ...Value) (int, error) {
res := C._jni_CallStaticIntMethodA(e.env, C.jclass(cls), C.jmethodID(method), varArgs(args))
func CallStaticIntMethod(e *Env, cls Class, method MethodID, args ...Value) (int, error) {
res := C._jni_CallStaticIntMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args))
return int(res), exception(e)
}
func CallStaticVoidMethod(e Env, cls Class, method MethodID, args ...Value) error {
C._jni_CallStaticVoidMethodA(e.env, C.jclass(cls), C.jmethodID(method), varArgs(args))
func CallStaticVoidMethod(e *Env, cls Class, method MethodID, args ...Value) error {
C._jni_CallStaticVoidMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args))
return exception(e)
}
func CallVoidMethod(e Env, obj Object, method MethodID, args ...Value) error {
C._jni_CallVoidMethodA(e.env, C.jobject(obj), C.jmethodID(method), varArgs(args))
func CallVoidMethod(e *Env, obj Object, method MethodID, args ...Value) error {
C._jni_CallVoidMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args))
return exception(e)
}
func CallStaticObjectMethod(e Env, cls Class, method MethodID, args ...Value) (Object, error) {
res := C._jni_CallStaticObjectMethodA(e.env, C.jclass(cls), C.jmethodID(method), varArgs(args))
func CallStaticObjectMethod(e *Env, cls Class, method MethodID, args ...Value) (Object, error) {
res := C._jni_CallStaticObjectMethodA(env(e), C.jclass(cls), C.jmethodID(method), varArgs(args))
return Object(res), exception(e)
}
func CallObjectMethod(e Env, obj Object, method MethodID, args ...Value) (Object, error) {
res := C._jni_CallObjectMethodA(e.env, C.jobject(obj), C.jmethodID(method), varArgs(args))
func CallObjectMethod(e *Env, obj Object, method MethodID, args ...Value) (Object, error) {
res := C._jni_CallObjectMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args))
return Object(res), exception(e)
}
func CallBooleanMethod(e Env, obj Object, method MethodID, args ...Value) (bool, error) {
res := C._jni_CallBooleanMethodA(e.env, C.jobject(obj), C.jmethodID(method), varArgs(args))
func CallBooleanMethod(e *Env, obj Object, method MethodID, args ...Value) (bool, error) {
res := C._jni_CallBooleanMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args))
return res == C.JNI_TRUE, exception(e)
}
func CallIntMethod(e Env, obj Object, method MethodID, args ...Value) (int32, error) {
res := C._jni_CallIntMethodA(e.env, C.jobject(obj), C.jmethodID(method), varArgs(args))
func CallIntMethod(e *Env, obj Object, method MethodID, args ...Value) (int32, error) {
res := C._jni_CallIntMethodA(env(e), C.jobject(obj), C.jmethodID(method), varArgs(args))
return int32(res), exception(e)
}
// GetByteArrayElements returns the contents of the array.
func GetByteArrayElements(e Env, jarr ByteArray) []byte {
func GetByteArrayElements(e *Env, jarr ByteArray) []byte {
if jarr == 0 {
return nil
}
size := C._jni_GetArrayLength(e.env, C.jarray(jarr))
elems := C._jni_GetByteArrayElements(e.env, C.jbyteArray(jarr))
defer C._jni_ReleaseByteArrayElements(e.env, C.jbyteArray(jarr), elems, 0)
size := C._jni_GetArrayLength(env(e), C.jarray(jarr))
elems := C._jni_GetByteArrayElements(env(e), C.jbyteArray(jarr))
defer C._jni_ReleaseByteArrayElements(env(e), C.jbyteArray(jarr), elems, 0)
backing := (*(*[1 << 30]byte)(unsafe.Pointer(elems)))[:size:size]
s := make([]byte, len(backing))
copy(s, backing)
@ -163,13 +159,13 @@ func GetByteArrayElements(e Env, jarr ByteArray) []byte {
// NewByteArray allocates a Java byte array with the content. It
// panics if the allocation fails.
func NewByteArray(e Env, content []byte) ByteArray {
jarr := C._jni_NewByteArray(e.env, C.jsize(len(content)))
func NewByteArray(e *Env, content []byte) ByteArray {
jarr := C._jni_NewByteArray(env(e), C.jsize(len(content)))
if jarr == 0 {
panic(fmt.Errorf("jni: NewByteArray(%d) failed", len(content)))
}
elems := C._jni_GetByteArrayElements(e.env, jarr)
defer C._jni_ReleaseByteArrayElements(e.env, jarr, elems, 0)
elems := C._jni_GetByteArrayElements(env(e), jarr)
defer C._jni_ReleaseByteArrayElements(env(e), jarr, elems, 0)
backing := (*(*[1 << 30]byte)(unsafe.Pointer(elems)))[:len(content):len(content)]
copy(backing, content)
return ByteArray(jarr)
@ -177,7 +173,7 @@ func NewByteArray(e Env, content []byte) ByteArray {
// ClassLoader returns a reference to the Java ClassLoader associated
// with obj.
func ClassLoaderFor(e Env, obj Object) Object {
func ClassLoaderFor(e *Env, obj Object) Object {
cls := GetObjectClass(e, obj)
getClassLoader := GetMethodID(e, cls, "getClassLoader", "()Ljava/lang/ClassLoader;")
clsLoader, err := CallObjectMethod(e, Object(obj), getClassLoader)
@ -190,7 +186,7 @@ func ClassLoaderFor(e Env, obj Object) Object {
// LoadClass invokes the underlying ClassLoader's loadClass method and
// returns the class.
func LoadClass(e Env, loader Object, class string) (Class, error) {
func LoadClass(e *Env, loader Object, class string) (Class, error) {
cls := GetObjectClass(e, loader)
loadClass := GetMethodID(e, cls, "loadClass", "(Ljava/lang/String;)Ljava/lang/Class;")
name := JavaString(e, class)
@ -204,12 +200,12 @@ func LoadClass(e Env, loader Object, class string) (Class, error) {
// exception returns an error corresponding to the pending
// exception, and clears it. exceptionError returns nil if no
// exception is pending.
func exception(e Env) error {
thr := C._jni_ExceptionOccurred(e.env)
func exception(e *Env) error {
thr := C._jni_ExceptionOccurred(env(e))
if thr == 0 {
return nil
}
C._jni_ExceptionClear(e.env)
C._jni_ExceptionClear(env(e))
cls := GetObjectClass(e, Object(thr))
toString := GetMethodID(e, cls, "toString", "()Ljava/lang/String;")
msg, err := CallObjectMethod(e, Object(thr), toString)
@ -220,11 +216,11 @@ func exception(e Env) error {
}
// GetObjectClass returns the Java Class for an Object.
func GetObjectClass(e Env, obj Object) Class {
func GetObjectClass(e *Env, obj Object) Class {
if obj == 0 {
panic("null object")
}
cls := C._jni_GetObjectClass(e.env, C.jobject(obj))
cls := C._jni_GetObjectClass(env(e), C.jobject(obj))
if err := exception(e); err != nil {
// GetObjectClass should never fail.
panic(err)
@ -234,12 +230,12 @@ func GetObjectClass(e Env, obj Object) Class {
// GetStaticMethodID returns the id for a static method. It panics if the method
// wasn't found.
func GetStaticMethodID(e Env, cls Class, name, signature string) MethodID {
func GetStaticMethodID(e *Env, cls Class, name, signature string) MethodID {
mname := C.CString(name)
defer C.free(unsafe.Pointer(mname))
msig := C.CString(signature)
defer C.free(unsafe.Pointer(msig))
m := C._jni_GetStaticMethodID(e.env, C.jclass(cls), mname, msig)
m := C._jni_GetStaticMethodID(env(e), C.jclass(cls), mname, msig)
if err := exception(e); err != nil {
panic(err)
}
@ -248,43 +244,43 @@ func GetStaticMethodID(e Env, cls Class, name, signature string) MethodID {
// GetMethodID returns the id for a method. It panics if the method
// wasn't found.
func GetMethodID(e Env, cls Class, name, signature string) MethodID {
func GetMethodID(e *Env, cls Class, name, signature string) MethodID {
mname := C.CString(name)
defer C.free(unsafe.Pointer(mname))
msig := C.CString(signature)
defer C.free(unsafe.Pointer(msig))
m := C._jni_GetMethodID(e.env, C.jclass(cls), mname, msig)
m := C._jni_GetMethodID(env(e), C.jclass(cls), mname, msig)
if err := exception(e); err != nil {
panic(err)
}
return MethodID(m)
}
func NewGlobalRef(e Env, obj Object) Object {
return Object(C._jni_NewGlobalRef(e.env, C.jobject(obj)))
func NewGlobalRef(e *Env, obj Object) Object {
return Object(C._jni_NewGlobalRef(env(e), C.jobject(obj)))
}
func DeleteGlobalRef(e Env, obj Object) {
C._jni_DeleteGlobalRef(e.env, C.jobject(obj))
func DeleteGlobalRef(e *Env, obj Object) {
C._jni_DeleteGlobalRef(env(e), C.jobject(obj))
}
// JavaString converts the string to a JVM jstring.
func JavaString(e Env, str string) String {
func JavaString(e *Env, str string) String {
if str == "" {
return 0
}
utf16Chars := utf16.Encode([]rune(str))
res := C._jni_NewString(e.env, (*C.jchar)(unsafe.Pointer(&utf16Chars[0])), C.int(len(utf16Chars)))
res := C._jni_NewString(env(e), (*C.jchar)(unsafe.Pointer(&utf16Chars[0])), C.int(len(utf16Chars)))
return String(res)
}
// GoString converts the JVM jstring to a Go string.
func GoString(e Env, str String) string {
func GoString(e *Env, str String) string {
if str == 0 {
return ""
}
strlen := C._jni_GetStringLength(e.env, C.jstring(str))
chars := C._jni_GetStringChars(e.env, C.jstring(str))
strlen := C._jni_GetStringLength(env(e), C.jstring(str))
chars := C._jni_GetStringChars(env(e), C.jstring(str))
var utf16Chars []uint16
hdr := (*reflect.SliceHeader)(unsafe.Pointer(&utf16Chars))
hdr.Data = uintptr(unsafe.Pointer(chars))

Loading…
Cancel
Save