diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc')
-rw-r--r-- | tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index fdcf00a0a0..abb7320bc5 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { return outputs; } - int getDataType(TfLiteType data_type) { switch (data_type) { case kTfLiteFloat32: @@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, } JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32( + JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, - jclass clazz, - jlong handle, - jint num_threads) { + jclass clazz, + jlong handle, + jint num_threads) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return; interpreter->SetNumThreads(static_cast<int>(num_threads)); |