aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc')
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdcf00a0a0..abb7320bc5 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-
int getDataType(TfLiteType data_type) {
switch (data_type) {
case kTfLiteFloat32:
@@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
}
JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
+}
+
+JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads) {
+ jclass clazz,
+ jlong handle,
+ jint num_threads) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetNumThreads(static_cast<int>(num_threads));