diff options
author | Jared Duke <jdduke@google.com> | 2018-10-03 10:18:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 10:22:41 -0700 |
commit | 022af5300701d457d848e60ea511dd8d05f68738 (patch) | |
tree | ff1a6b445c874fbd482a623a991b0502f4b8f3ed /tensorflow/contrib/lite/java | |
parent | 2af8fd975aaf5c70ebb396895fa15a8f034a8440 (diff) |
Fix TfLiteTensor invalidation issue when using the Java API
Fix an issue where the Java Tensor class would hold a reference
to an invalidated TfLiteTensor instance. This issue was manifest
in certain models that add temporary tensors during execution.
PiperOrigin-RevId: 215582842
Diffstat (limited to 'tensorflow/contrib/lite/java')
7 files changed, 129 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 9bc44bf797..6f03e7853a 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -18,7 +18,6 @@ package org.tensorflow.lite; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -83,6 +82,19 @@ final class NativeInterpreterWrapper implements AutoCloseable { /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @Override public void close() { + // Close the tensors first as they may reference the native interpreter. + for (int i = 0; i < inputTensors.length; ++i) { + if (inputTensors[i] != null) { + inputTensors[i].close(); + inputTensors[i] = null; + } + } + for (int i = 0; i < outputTensors.length; ++i) { + if (outputTensors[i] != null) { + outputTensors[i].close(); + outputTensors[i] = null; + } + } delete(errorHandle, modelHandle, interpreterHandle); errorHandle = 0; modelHandle = 0; @@ -91,8 +103,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { inputsIndexes = null; outputsIndexes = null; isMemoryAllocated = false; - Arrays.fill(inputTensors, null); - Arrays.fill(outputTensors, null); } /** Sets inputs, runs model inference and returns outputs. */ @@ -260,7 +270,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { Tensor inputTensor = inputTensors[index]; if (inputTensor == null) { inputTensor = - inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index)); + inputTensors[index] = + Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index)); } return inputTensor; } @@ -282,7 +293,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { Tensor outputTensor = outputTensors[index]; if (outputTensor == null) { outputTensor = - outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index)); + outputTensors[index] = + Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index)); } return outputTensor; } @@ -317,9 +329,9 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native long allocateTensors(long interpreterHandle, long errorHandle); - private static native long getInputTensor(long interpreterHandle, int inputIdx); + private static native int getInputTensorIndex(long interpreterHandle, int inputIdx); - private static native long getOutputTensor(long interpreterHandle, int outputIdx); + private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx); private static native int getInputCount(long interpreterHandle); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index f174178d98..6ca47aa3ed 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -23,13 +23,26 @@ import java.util.Arrays; /** * A typed multi-dimensional array used in Tensorflow Lite. * - * <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not - * needed to be closed here. + * <p>The native handle of a {@code Tensor} is managed by {@code NativeInterpreterWrapper}, and does + * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has + * been closed, the tensor handle will be invalidated. */ public final class Tensor { - static Tensor fromHandle(long nativeHandle) { - return new Tensor(nativeHandle); + /** + * Creates a Tensor wrapper from the provided interpreter instance and tensor index. + * + * <p>The caller is responsible for closing the created wrapper, and ensuring the provided + * native interpreter is valid until the tensor is closed. + */ + static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) { + return new Tensor(create(nativeInterpreterHandle, tensorIndex)); + } + + /** Disposes of any resources used by the Tensor wrapper. */ + void close() { + delete(nativeHandle); + nativeHandle = 0; } /** Returns the {@link DataType} of elements stored in the Tensor. */ @@ -235,7 +248,7 @@ public final class Tensor { return o instanceof ByteBuffer; } - private final long nativeHandle; + private long nativeHandle; private final DataType dtype; private int[] shapeCopy; @@ -249,6 +262,10 @@ public final class Tensor { return buffer(nativeHandle).order(ByteOrder.nativeOrder()); } + private static native long create(long interpreterHandle, int tensorIndex); + + private static native void delete(long handle); + private static native ByteBuffer buffer(long handle); private static native void writeDirectBuffer(long handle, ByteBuffer src); diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index abb7320bc5..4dc73fbcf8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -159,26 +159,20 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( } } -JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, - jclass clazz, - jlong handle, - jint index) { +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex( + JNIEnv* env, jclass clazz, jlong handle, jint input_index) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; - return reinterpret_cast<jlong>( - interpreter->tensor(interpreter->inputs()[index])); + return interpreter->inputs()[input_index]; } -JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, - jclass clazz, - jlong handle, - jint index) { +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex( + JNIEnv* env, jclass clazz, jlong handle, jint output_index) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return 0; - return reinterpret_cast<jlong>( - interpreter->tensor(interpreter->outputs()[index])); + return interpreter->outputs()[output_index]; } JNIEXPORT jint JNICALL diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index aa809dff8a..f8f3e7028c 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -46,25 +46,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: getInputTensor - * Signature: (JI)J + * Method: getInputTensorIndex + * Signature: (JI)I */ -JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, - jclass clazz, - jlong handle, - jint index); +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex( + JNIEnv* env, jclass clazz, jlong handle, jint input_index); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: getOutputTensor - * Signature: (JI)J + * Method: getOutputTensorIndex + * Signature: (JI)I */ -JNIEXPORT jlong JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, - jclass clazz, - jlong handle, - jint index); +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex( + JNIEnv* env, jclass clazz, jlong handle, jint output_index); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 7ff96a3172..d3378f5f14 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -16,17 +16,36 @@ limitations under the License. #include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" #include <cstring> #include <memory> +#include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" namespace { -TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { +// Convenience handle for obtaining a TfLiteTensor given an interpreter and +// tensor index. +// +// Historically, the Java Tensor class used a TfLiteTensor pointer as its native +// handle. However, this approach isn't generally safe, as the interpreter may +// invalidate all TfLiteTensor* handles during inference or allocation. +class TensorHandle { + public: + TensorHandle(tflite::Interpreter* interpreter, int tensor_index) + : interpreter_(interpreter), tensor_index_(tensor_index) {} + + TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); } + + private: + tflite::Interpreter* const interpreter_; + const int tensor_index_; +}; + +TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) { if (handle == 0) { throwException(env, kIllegalArgumentException, "Internal error: Invalid handle to TfLiteTensor."); return nullptr; } - return reinterpret_cast<TfLiteTensor*>(handle); + return reinterpret_cast<TensorHandle*>(handle)->tensor(); } size_t elementByteSize(TfLiteType data_type) { @@ -192,10 +211,23 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } // namespace +JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) { + tflite::Interpreter* interpreter = + reinterpret_cast<tflite::Interpreter*>(interpreter_handle); + return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env, + jclass clazz, + jlong handle) { + delete reinterpret_cast<TensorHandle*>(handle); +} + JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, jclass clazz, jlong handle) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return nullptr; if (tensor->data.raw == nullptr) { throwException(env, kIllegalArgumentException, @@ -208,7 +240,7 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( JNIEnv* env, jclass clazz, jlong handle, jobject src) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src)); @@ -226,7 +258,7 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, jlong handle, jobject value) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; int num_dims = tensor->dims->size; if (num_dims == 0) { @@ -243,7 +275,7 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, jclass clazz, jlong handle, jobject src) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; if (tensor->data.raw == nullptr) { throwException(env, kIllegalArgumentException, @@ -262,14 +294,14 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return 0; return static_cast<jint>(tensor->type); } JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { - TfLiteTensor* tensor = convertLongToTensor(env, handle); + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return nullptr; int num_dims = tensor->dims->size; jintArray result = env->NewIntArray(num_dims); @@ -280,7 +312,7 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, jclass clazz, jlong handle) { - const TfLiteTensor* tensor = convertLongToTensor(env, handle); + const TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return 0; return static_cast<jint>(tensor->bytes); } diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index 2f73128bdf..c5e9690e9a 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -25,6 +25,23 @@ extern "C" { /* * Class: org_tensorflow_lite_Tensor + * Method: create + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: delete + * Signature: (J) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor * Method: buffer * Signature: (J)Ljava/nio/ByteBuffer; */ diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 85ad393d89..56a38ea3e2 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -182,7 +182,7 @@ public final class TensorTest { dataType = Tensor.dataTypeOf(testFloatArray); assertThat(dataType).isEqualTo(DataType.FLOAT32); float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; - dataType = Tensor.dataTypeOf(testFloatArray); + dataType = Tensor.dataTypeOf(testMultiDimArray); assertThat(dataType).isEqualTo(DataType.FLOAT32); try { double[] testDoubleArray = {0.783, 0.251}; @@ -238,4 +238,15 @@ public final class TensorTest { assertThat(shape[1]).isEqualTo(3); assertThat(shape[2]).isEqualTo(1); } + + @Test + public void testUseAfterClose() { + tensor.close(); + try { + tensor.numBytes(); + fail(); + } catch (IllegalArgumentException e) { + // Expected failure. + } + } } |