about summary refs log tree commit diff homepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 14:22:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:27:23 -0700
commit9280b3c8a41150022d3ea508f01959ac954c9f73 (patch)
treef5e6dab79e0fba1705524324d2591caee6d1ea2c /tensorflow/contrib
parent72b927960625cd2920fea06e242df1ff0d220c77 (diff)
Add an experimental Java API to allow half precision for FP32 calculation.
PiperOrigin-RevId: 214668283
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java16
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java12
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc15
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h9
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java4
5 files changed, 49 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index ffb04496cb..eacfa0c827 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -74,8 +74,19 @@ public final class Interpreter implements AutoCloseable {
return this;
}
+ /**
+ * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
+ * (disallow).
+ * WARNING: This is an experimental API and subject to change.
+ */
+ public Options setAllowFp16PrecisionForFp32(boolean allow) {
+ this.allowFp16PrecisionForFp32 = allow;
+ return this;
+ }
+
int numThreads = -1;
boolean useNNAPI = false;
+ boolean allowFp16PrecisionForFp32 = false;
}
/**
@@ -256,8 +267,9 @@ public final class Interpreter implements AutoCloseable {
/**
* Returns native inference timing.
- * <p>IllegalArgumentException will be thrown if the model is not initialized by the
- * {@link Interpreter}.
+ *
+ * <p>IllegalArgumentException will be thrown if the model is not initialized by the {@link
+ * Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
checkNotClosed();
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 6feff9a618..9bc44bf797 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -45,6 +45,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
NativeInterpreterWrapper(ByteBuffer byteBuffer) {
@@ -72,6 +75,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
if (options.useNNAPI) {
setUseNNAPI(options.useNNAPI);
}
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -159,6 +165,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
useNNAPI(interpreterHandle, useNNAPI);
}
+ void setAllowFp16PrecisionForFp32(boolean allow) {
+ allowFp16PrecisionForFp32(interpreterHandle, allow);
+ }
+
void setNumThreads(int numThreads) {
numThreads(interpreterHandle, numThreads);
}
@@ -323,6 +333,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void numThreads(long interpreterHandle, int numThreads);
+ private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);
+
private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdcf00a0a0..abb7320bc5 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-
int getDataType(TfLiteType data_type) {
switch (data_type) {
case kTfLiteFloat32:
@@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
}
JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
+}
+
+JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads) {
+ jclass clazz,
+ jlong handle,
+ jint num_threads) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetNumThreads(static_cast<int>(num_threads));
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 06b35d77c8..aa809dff8a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -120,6 +120,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
+ * Signature: (JZ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
* Signature: (JI)V
*/
JNIEXPORT void JNICALL
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index dfdd7d22b0..fdd5063156 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -323,7 +323,9 @@ public final class InterpreterTest {
MappedByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
Interpreter interpreter =
- new Interpreter(mappedByteBuffer, new Interpreter.Options().setUseNNAPI(true));
+ new Interpreter(
+ mappedByteBuffer,
+ new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true));
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};