aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-27 14:37:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-27 14:41:44 -0800
commit80b6956b7cf4a092ff0780d133cd2faad4cda704 (patch)
tree15da8ae9e81ccf07854b7538f454fa1240a2e4d8 /tensorflow/contrib/lite/java/src
parent246cad289498357523517b67a3f214960dfa0f92 (diff)
Added a TFLite Java API to get last inference latency in nanoseconds.
PiperOrigin-RevId: 187234119
Diffstat (limited to 'tensorflow/contrib/lite/java/src')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java16
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc38
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc12
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h9
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java41
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java15
6 files changed, 126 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 5ee594dec4..7612be0ddd 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -91,8 +91,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
i, inputs.length));
}
}
+ inferenceDurationNanoseconds = -1;
long[] outputsHandles =
- run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs);
+ run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs, this);
if (outputsHandles == null || outputsHandles.length == 0) {
throw new IllegalStateException("Interpreter has no outputs.");
}
@@ -109,7 +110,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
Object[] sizes,
int[] dtypes,
int[] numsOfBytes,
- Object[] values);
+ Object[] values,
+ NativeInterpreterWrapper wrapper);
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
@@ -236,6 +238,14 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
+ /**
+ * Gets the last inference duration in nanoseconds. It returns null if there is no previous
+ * inference run or the last inference run failed.
+ */
+ Long getLastNativeInferenceDurationNanoseconds() {
+ return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
+ }
+
private static final int ERROR_BUFFER_SIZE = 512;
private long errorHandle;
@@ -246,6 +256,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private int inputSize;
+ private long inferenceDurationNanoseconds = -1;
+
private MappedByteBuffer modelByteBuffer;
private Map<String, Integer> inputsIndexes;
diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
new file mode 100644
index 0000000000..0e08a04370
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+#include <time.h>
+
+namespace tflite {
+
+// Gets the elapsed wall-clock timespec.
+timespec getCurrentTime() {
+ timespec time;
+ clock_gettime(CLOCK_MONOTONIC, &time);
+ return time;
+}
+
+// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier
+// than 'start'.
+jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) {
+ jlong result = stop->tv_sec - start->tv_sec;
+ if (result < 0) return -1;
+ result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec);
+ if (result < 0) return -1;
+ return result;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index c346f9f92e..e405df0745 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -353,7 +353,7 @@ JNIEXPORT jlongArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values) {
+ jobjectArray values, jobject wrapper) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) return nullptr;
@@ -384,6 +384,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
values);
if (status != kTfLiteOk) return nullptr;
+ timespec beforeInference = ::tflite::getCurrentTime();
// runs inference
if (interpreter->Invoke() != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
@@ -391,6 +392,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
error_reporter->CachedErrorMessage());
return nullptr;
}
+ timespec afterInference = ::tflite::getCurrentTime();
+ jclass wrapper_clazz = env->GetObjectClass(wrapper);
+ jfieldID fid =
+ env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J");
+ if (fid != 0) {
+ env->SetLongField(
+ wrapper, fid,
+ ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference));
+ }
// returns outputs
const std::vector<int>& results = interpreter->outputs();
if (results.empty()) {
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index c52a7e4e43..31c8f1bc88 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <jni.h>
#include <stdio.h>
+#include <time.h>
#include <vector>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -28,6 +29,9 @@ limitations under the License.
namespace tflite {
// This is to be provided at link-time by a library.
extern std::unique_ptr<OpResolver> CreateOpResolver();
+extern timespec getCurrentTime();
+extern jlong timespec_diff_nanoseconds(struct timespec* start,
+ struct timespec* stop);
} // namespace tflite
#ifdef __cplusplus
@@ -104,13 +108,14 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
- * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J
+ * Signature:
+ * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Lorg/tensorflow/lite/NativeInterpreterWrapper;)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values);
+ jobjectArray values, jobject wrapper);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 90323555d8..8c1f2406f7 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -417,4 +417,45 @@ public final class NativeInterpreterWrapperTest {
assertThat(shape[1]).isEqualTo(3);
assertThat(shape[2]).isEqualTo(1);
}
+
+ @Test
+ public void testGetInferenceLatency() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ Tensor[] outputs = wrapper.run(inputs);
+ assertThat(outputs.length).isEqualTo(1);
+ assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
+ wrapper.close();
+ }
+
+ @Test
+ public void testGetInferenceLatencyWithNewWrapper() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
+ wrapper.close();
+ }
+
+ @Test
+ public void testGetLatencyAfterFailedInference() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ try {
+ wrapper.run(inputs);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ }
+ assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
+ wrapper.close();
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 8660cabf70..a5c13053d7 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -32,4 +32,19 @@ public class TestHelper {
throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
}
}
+
+ /**
+ * Gets the last inference duration in nanoseconds. It returns null if there is no previous
+ * inference run or the last inference run failed.
+ *
+ * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+ * IllegalArgumentException} will be thrown.
+ */
+ public static Long getLastNativeInferenceDurationNanoseconds(Interpreter interpreter) {
+ if (interpreter != null && interpreter.wrapper != null) {
+ return interpreter.wrapper.getLastNativeInferenceDurationNanoseconds();
+ } else {
+ throw new IllegalArgumentException("Interpreter has not initialized; Failed to get latency.");
+ }
+ }
}