diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src')
6 files changed, 126 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 5ee594dec4..7612be0ddd 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -91,8 +91,9 @@ final class NativeInterpreterWrapper implements AutoCloseable { i, inputs.length)); } } + inferenceDurationNanoseconds = -1; long[] outputsHandles = - run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs, this); if (outputsHandles == null || outputsHandles.length == 0) { throw new IllegalStateException("Interpreter has no outputs."); } @@ -109,7 +110,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { Object[] sizes, int[] dtypes, int[] numsOfBytes, - Object[] values); + Object[] values, + NativeInterpreterWrapper wrapper); /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { @@ -236,6 +238,14 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } + /** + * Gets the last inference duration in nanoseconds. It returns null if there is no previous + * inference run or the last inference run failed. + */ + Long getLastNativeInferenceDurationNanoseconds() { + return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds; + } + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; @@ -246,6 +256,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { private int inputSize; + private long inferenceDurationNanoseconds = -1; + private MappedByteBuffer modelByteBuffer; private Map<String, Integer> inputsIndexes; diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc new file mode 100644 index 0000000000..0e08a04370 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> +#include <time.h> + +namespace tflite { + +// Gets the elapsed wall-clock timespec. +timespec getCurrentTime() { + timespec time; + clock_gettime(CLOCK_MONOTONIC, &time); + return time; +} + +// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier +// than 'start'. +jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) { + jlong result = stop->tv_sec - start->tv_sec; + if (result < 0) return -1; + result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec); + if (result < 0) return -1; + return result; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index c346f9f92e..e405df0745 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -353,7 +353,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values) { + jobjectArray values, jobject wrapper) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return nullptr; @@ -384,6 +384,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run( status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, values); if (status != kTfLiteOk) return nullptr; + timespec beforeInference = ::tflite::getCurrentTime(); // runs inference if (interpreter->Invoke() != kTfLiteOk) { throwException(env, kIllegalArgumentException, @@ -391,6 +392,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run( error_reporter->CachedErrorMessage()); return nullptr; } + timespec afterInference = ::tflite::getCurrentTime(); + jclass wrapper_clazz = env->GetObjectClass(wrapper); + jfieldID fid = + env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J"); + if (fid != 0) { + env->SetLongField( + wrapper, fid, + ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference)); + } // returns outputs const std::vector<int>& results = interpreter->outputs(); if (results.empty()) { diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index c52a7e4e43..31c8f1bc88 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -18,6 +18,7 @@ limitations under the License. #include <jni.h> #include <stdio.h> +#include <time.h> #include <vector> #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -28,6 +29,9 @@ limitations under the License. namespace tflite { // This is to be provided at link-time by a library. extern std::unique_ptr<OpResolver> CreateOpResolver(); +extern timespec getCurrentTime(); +extern jlong timespec_diff_nanoseconds(struct timespec* start, + struct timespec* stop); } // namespace tflite #ifdef __cplusplus @@ -104,13 +108,14 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + * Signature: + * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Lorg/tensorflow/lite/NativeInterpreterWrapper;)[J */ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values); + jobjectArray values, jobject wrapper); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 90323555d8..8c1f2406f7 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -417,4 +417,45 @@ public final class NativeInterpreterWrapperTest { assertThat(shape[1]).isEqualTo(3); assertThat(shape[2]).isEqualTo(1); } + + @Test + public void testGetInferenceLatency() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L); + wrapper.close(); + } + + @Test + public void testGetInferenceLatencyWithNewWrapper() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); + wrapper.close(); + } + + @Test + public void testGetLatencyAfterFailedInference() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); + wrapper.close(); + } } diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 8660cabf70..a5c13053d7 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -32,4 +32,19 @@ public class TestHelper { throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); } } + + /** + * Gets the last inference duration in nanoseconds. It returns null if there is no previous + * inference run or the last inference run failed. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + */ + public static Long getLastNativeInferenceDurationNanoseconds(Interpreter interpreter) { + if (interpreter != null && interpreter.wrapper != null) { + return interpreter.wrapper.getLastNativeInferenceDurationNanoseconds(); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to get latency."); + } + } } |