diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 14:22:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 14:27:23 -0700 |
commit | 9280b3c8a41150022d3ea508f01959ac954c9f73 (patch) | |
tree | f5e6dab79e0fba1705524324d2591caee6d1ea2c /tensorflow/contrib | |
parent | 72b927960625cd2920fea06e242df1ff0d220c77 (diff) |
Add an experimental Java API to allow half precision for FP32 calculation.
PiperOrigin-RevId: 214668283
Diffstat (limited to 'tensorflow/contrib')
5 files changed, 49 insertions, 7 deletions
```diff
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index ffb04496cb..eacfa0c827 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -74,8 +74,19 @@ public final class Interpreter implements AutoCloseable {
       return this;
     }
 
+    /**
+     * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
+     * (disallow).
+     * WARNING: This is an experimental API and subject to change.
+     */
+    public Options setAllowFp16PrecisionForFp32(boolean allow) {
+      this.allowFp16PrecisionForFp32 = allow;
+      return this;
+    }
+
     int numThreads = -1;
     boolean useNNAPI = false;
+    boolean allowFp16PrecisionForFp32 = false;
   }
 
   /**
@@ -256,8 +267,9 @@ public final class Interpreter implements AutoCloseable {
 
   /**
    * Returns native inference timing.
-   * <p>IllegalArgumentException will be thrown if the model is not initialized by the
-   * {@link Interpreter}.
+   *
+   * <p>IllegalArgumentException will be thrown if the model is not initialized by the {@link
+   * Interpreter}.
    */
   public Long getLastNativeInferenceDurationNanoseconds() {
     checkNotClosed();
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 6feff9a618..9bc44bf797 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -45,6 +45,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     isMemoryAllocated = true;
     inputTensors = new Tensor[getInputCount(interpreterHandle)];
     outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+    if (options.allowFp16PrecisionForFp32) {
+      setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+    }
   }
 
   NativeInterpreterWrapper(ByteBuffer byteBuffer) {
@@ -72,6 +75,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     if (options.useNNAPI) {
       setUseNNAPI(options.useNNAPI);
     }
+    if (options.allowFp16PrecisionForFp32) {
+      setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+    }
   }
 
   /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -159,6 +165,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     useNNAPI(interpreterHandle, useNNAPI);
   }
 
+  void setAllowFp16PrecisionForFp32(boolean allow) {
+    allowFp16PrecisionForFp32(interpreterHandle, allow);
+  }
+
   void setNumThreads(int numThreads) {
     numThreads(interpreterHandle, numThreads);
   }
@@ -323,6 +333,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void numThreads(long interpreterHandle, int numThreads);
 
+  private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);
+
   private static native long createErrorReporter(int size);
 
   private static native long createModel(String modelPathOrBuffer, long errorHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdcf00a0a0..abb7320bc5 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
   return outputs;
 }
 
-
 int getDataType(TfLiteType data_type) {
   switch (data_type) {
     case kTfLiteFloat32:
@@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 }
 
 JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+    JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return;
+  interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
+}
+
+JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                              jclass clazz,
-                                                              jlong handle,
-                                                              jint num_threads) {
+                                                             jclass clazz,
+                                                             jlong handle,
+                                                             jint num_threads) {
   tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetNumThreads(static_cast<int>(num_threads));
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 06b35d77c8..aa809dff8a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -120,6 +120,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
 /*
  * Class: org_tensorflow_lite_NativeInterpreterWrapper
  * Method:
+ * Signature: (JZ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+    JNIEnv* env, jclass clazz, jlong handle, jboolean allow);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
  * Signature: (JI)V
  */
 JNIEXPORT void JNICALL
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index dfdd7d22b0..fdd5063156 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -323,7 +323,9 @@ public final class InterpreterTest {
     MappedByteBuffer mappedByteBuffer =
         fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
     Interpreter interpreter =
-        new Interpreter(mappedByteBuffer, new Interpreter.Options().setUseNNAPI(true));
+        new Interpreter(
+            mappedByteBuffer,
+            new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true));
     float[] oneD = {1.23f, 6.54f, 7.81f};
     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
```