aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-10-03 10:18:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 10:22:41 -0700
commit022af5300701d457d848e60ea511dd8d05f68738 (patch)
treeff1a6b445c874fbd482a623a991b0502f4b8f3ed /tensorflow/contrib/lite/java
parent2af8fd975aaf5c70ebb396895fa15a8f034a8440 (diff)
Fix TfLiteTensor invalidation issue when using the Java API
Fix an issue where the Java Tensor class would hold a reference to an invalidated TfLiteTensor instance. This issue was manifest in certain models that add temporary tensors during execution. PiperOrigin-RevId: 215582842
Diffstat (limited to 'tensorflow/contrib/lite/java')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java26
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java27
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc22
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h24
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc50
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h17
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java13
7 files changed, 129 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 9bc44bf797..6f03e7853a 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -18,7 +18,6 @@ package org.tensorflow.lite;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
-import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -83,6 +82,19 @@ final class NativeInterpreterWrapper implements AutoCloseable {
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@Override
public void close() {
+ // Close the tensors first as they may reference the native interpreter.
+ for (int i = 0; i < inputTensors.length; ++i) {
+ if (inputTensors[i] != null) {
+ inputTensors[i].close();
+ inputTensors[i] = null;
+ }
+ }
+ for (int i = 0; i < outputTensors.length; ++i) {
+ if (outputTensors[i] != null) {
+ outputTensors[i].close();
+ outputTensors[i] = null;
+ }
+ }
delete(errorHandle, modelHandle, interpreterHandle);
errorHandle = 0;
modelHandle = 0;
@@ -91,8 +103,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
- Arrays.fill(inputTensors, null);
- Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
@@ -260,7 +270,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
Tensor inputTensor = inputTensors[index];
if (inputTensor == null) {
inputTensor =
- inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ inputTensors[index] =
+ Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index));
}
return inputTensor;
}
@@ -282,7 +293,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
Tensor outputTensor = outputTensors[index];
if (outputTensor == null) {
outputTensor =
- outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ outputTensors[index] =
+ Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
}
return outputTensor;
}
@@ -317,9 +329,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native long allocateTensors(long interpreterHandle, long errorHandle);
- private static native long getInputTensor(long interpreterHandle, int inputIdx);
+ private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
- private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+ private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);
private static native int getInputCount(long interpreterHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index f174178d98..6ca47aa3ed 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -23,13 +23,26 @@ import java.util.Arrays;
/**
* A typed multi-dimensional array used in Tensorflow Lite.
*
- * <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
- * needed to be closed here.
+ * <p>The native handle of a {@code Tensor} is managed by {@code NativeInterpreterWrapper}, and does
+ * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
+ * been closed, the tensor handle will be invalidated.
*/
public final class Tensor {
- static Tensor fromHandle(long nativeHandle) {
- return new Tensor(nativeHandle);
+ /**
+ * Creates a Tensor wrapper from the provided interpreter instance and tensor index.
+ *
+ * <p>The caller is responsible for closing the created wrapper, and ensuring the provided
+ * native interpreter is valid until the tensor is closed.
+ */
+ static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) {
+ return new Tensor(create(nativeInterpreterHandle, tensorIndex));
+ }
+
+ /** Disposes of any resources used by the Tensor wrapper. */
+ void close() {
+ delete(nativeHandle);
+ nativeHandle = 0;
}
/** Returns the {@link DataType} of elements stored in the Tensor. */
@@ -235,7 +248,7 @@ public final class Tensor {
return o instanceof ByteBuffer;
}
- private final long nativeHandle;
+ private long nativeHandle;
private final DataType dtype;
private int[] shapeCopy;
@@ -249,6 +262,10 @@ public final class Tensor {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
}
+ private static native long create(long interpreterHandle, int tensorIndex);
+
+ private static native void delete(long handle);
+
private static native ByteBuffer buffer(long handle);
private static native void writeDirectBuffer(long handle, ByteBuffer src);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index abb7320bc5..4dc73fbcf8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -159,26 +159,20 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
}
}
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index) {
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
- return reinterpret_cast<jlong>(
- interpreter->tensor(interpreter->inputs()[index]));
+ return interpreter->inputs()[input_index];
}
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index) {
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
- return reinterpret_cast<jlong>(
- interpreter->tensor(interpreter->outputs()[index]));
+ return interpreter->outputs()[output_index];
}
JNIEXPORT jint JNICALL
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index aa809dff8a..f8f3e7028c 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -46,25 +46,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method: getInputTensor
- * Signature: (JI)J
+ * Method: getInputTensorIndex
+ * Signature: (JI)I
*/
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index);
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_index);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method: getOutputTensor
- * Signature: (JI)J
+ * Method: getOutputTensorIndex
+ * Signature: (JI)I
*/
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index);
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_index);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 7ff96a3172..d3378f5f14 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -16,17 +16,36 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
#include <cstring>
#include <memory>
+#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
namespace {
-TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
+// Convenience handle for obtaining a TfLiteTensor given an interpreter and
+// tensor index.
+//
+// Historically, the Java Tensor class used a TfLiteTensor pointer as its native
+// handle. However, this approach isn't generally safe, as the interpreter may
+// invalidate all TfLiteTensor* handles during inference or allocation.
+class TensorHandle {
+ public:
+ TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
+ : interpreter_(interpreter), tensor_index_(tensor_index) {}
+
+ TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
+
+ private:
+ tflite::Interpreter* const interpreter_;
+ const int tensor_index_;
+};
+
+TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
"Internal error: Invalid handle to TfLiteTensor.");
return nullptr;
}
- return reinterpret_cast<TfLiteTensor*>(handle);
+ return reinterpret_cast<TensorHandle*>(handle)->tensor();
}
size_t elementByteSize(TfLiteType data_type) {
@@ -192,10 +211,23 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
} // namespace
+JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
+ tflite::Interpreter* interpreter =
+ reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
+ return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ delete reinterpret_cast<TensorHandle*>(handle);
+}
+
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return nullptr;
if (tensor->data.raw == nullptr) {
throwException(env, kIllegalArgumentException,
@@ -208,7 +240,7 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
JNIEnv* env, jclass clazz, jlong handle, jobject src) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
@@ -226,7 +258,7 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
int num_dims = tensor->dims->size;
if (num_dims == 0) {
@@ -243,7 +275,7 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject src) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
if (tensor->data.raw == nullptr) {
throwException(env, kIllegalArgumentException,
@@ -262,14 +294,14 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return 0;
return static_cast<jint>(tensor->type);
}
JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return nullptr;
int num_dims = tensor->dims->size;
jintArray result = env->NewIntArray(num_dims);
@@ -280,7 +312,7 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
jclass clazz,
jlong handle) {
- const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ const TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return 0;
return static_cast<jint>(tensor->bytes);
}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 2f73128bdf..c5e9690e9a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -25,6 +25,23 @@ extern "C" {
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: create
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
+ * Method: delete
+ * Signature: (J)
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: buffer
* Signature: (J)Ljava/nio/ByteBuffer;
*/
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 85ad393d89..56a38ea3e2 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -182,7 +182,7 @@ public final class TensorTest {
dataType = Tensor.dataTypeOf(testFloatArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = Tensor.dataTypeOf(testFloatArray);
+ dataType = Tensor.dataTypeOf(testMultiDimArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
try {
double[] testDoubleArray = {0.783, 0.251};
@@ -238,4 +238,15 @@ public final class TensorTest {
assertThat(shape[1]).isEqualTo(3);
assertThat(shape[2]).isEqualTo(1);
}
+
+ @Test
+ public void testUseAfterClose() {
+ tensor.close();
+ try {
+ tensor.numBytes();
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Expected failure.
+ }
+ }
}