aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-07-09 14:11:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 14:14:53 -0700
commit09924e82d576011dd9509e952a756dac1e3b9c60 (patch)
tree55700ca0310e4709dd5cd8f425825fc648cd452f /tensorflow/contrib/lite/java
parent16ffd7c6cf05f0817d584acca90ea195b19b0530 (diff)
Implement Interpreter.run() in terms of Tensor APIs
PiperOrigin-RevId: 203826817
Diffstat (limited to 'tensorflow/contrib/lite/java')
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java8
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java9
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java15
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java183
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java121
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc38
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc307
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h79
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc123
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h40
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java8
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java182
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java77
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java4
15 files changed, 568 insertions, 627 deletions
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 56f3e7604a..1587c3c56f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -127,12 +127,8 @@ public final class OvicClassifierTest {
try {
testResult = classifier.classifyByteBuffer(testImage);
fail();
- } catch (RuntimeException e) {
- assertThat(e)
- .hasMessageThat()
- .contains(
- "Failed to get input dimensions. 0-th input should have 49152 bytes, "
- + "but found 150528 bytes.");
+ } catch (IllegalArgumentException e) {
+ // Success.
}
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 75334cd96e..94a1ec65d6 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -27,10 +27,7 @@ enum DataType {
UINT8(3),
/** 64-bit signed integer. */
- INT64(4),
-
- /** A {@link ByteBuffer}. */
- BYTEBUFFER(999);
+ INT64(4);
private final int value;
@@ -69,8 +66,6 @@ enum DataType {
return 1;
case INT64:
return 8;
- case BYTEBUFFER:
- return 1;
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
@@ -87,8 +82,6 @@ enum DataType {
return "byte";
case INT64:
return "long";
- case BYTEBUFFER:
- return "ByteBuffer";
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 589fd6426f..7002f82677 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -165,20 +165,7 @@ public final class Interpreter implements AutoCloseable {
if (wrapper == null) {
throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
}
- Tensor[] tensors = wrapper.run(inputs);
- if (outputs == null || tensors == null || outputs.size() > tensors.length) {
- throw new IllegalArgumentException("Output error: Outputs do not match with model outputs.");
- }
- final int size = tensors.length;
- for (Integer idx : outputs.keySet()) {
- if (idx == null || idx < 0 || idx >= size) {
- throw new IllegalArgumentException(
- String.format(
- "Output error: Invalid index of output %d (should be in range [0, %d))",
- idx, size));
- }
- tensors[idx].copyTo(outputs.get(idx));
- }
+ wrapper.run(inputs, outputs);
}
/**
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 80de88b6a1..072cb26bb2 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -19,6 +19,7 @@ import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -40,6 +41,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModel(modelPath, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/**
@@ -72,6 +75,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -85,75 +90,63 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
+ Arrays.fill(inputTensors, null);
+ Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
- Tensor[] run(Object[] inputs) {
+ void run(Object[] inputs, Map<Integer, Object> outputs) {
+ inferenceDurationNanoseconds = -1;
if (inputs == null || inputs.length == 0) {
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
}
- int[] dataTypes = new int[inputs.length];
- Object[] sizes = new Object[inputs.length];
- int[] numsOfBytes = new int[inputs.length];
+ if (outputs == null || outputs.isEmpty()) {
+ throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
+ }
+
+ // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
+ // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
+ // flush all resizes, avoiding redundant allocations.
for (int i = 0; i < inputs.length; ++i) {
- DataType dataType = dataTypeOf(inputs[i]);
- dataTypes[i] = dataType.getNumber();
- if (dataType == DataType.BYTEBUFFER) {
- ByteBuffer buffer = (ByteBuffer) inputs[i];
- if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) {
- throw new IllegalArgumentException(
- "Input error: ByteBuffer should be a direct ByteBuffer that uses "
- + "ByteOrder.nativeOrder().");
- }
- numsOfBytes[i] = buffer.limit();
- sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
- } else if (isNonEmptyArray(inputs[i])) {
- int[] dims = shapeOf(inputs[i]);
- sizes[i] = dims;
- numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
- } else {
- throw new IllegalArgumentException(
- String.format(
- "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.",
- i, inputs.length));
+ Tensor tensor = getInputTensor(i);
+ int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
+ if (newShape != null) {
+ resizeInput(i, newShape);
}
}
- inferenceDurationNanoseconds = -1;
- long[] outputsHandles =
- run(
- interpreterHandle,
- errorHandle,
- sizes,
- dataTypes,
- numsOfBytes,
- inputs,
- this,
- isMemoryAllocated);
- if (outputsHandles == null || outputsHandles.length == 0) {
- throw new IllegalStateException("Internal error: Interpreter has no outputs.");
+
+ if (!isMemoryAllocated) {
+ allocateTensors(interpreterHandle, errorHandle);
+ isMemoryAllocated = true;
+ // Allocation can trigger dynamic resizing of output tensors, so clear the
+ // output tensor cache.
+ Arrays.fill(outputTensors, null);
}
- isMemoryAllocated = true;
- Tensor[] outputs = new Tensor[outputsHandles.length];
- for (int i = 0; i < outputsHandles.length; ++i) {
- outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+
+ for (int i = 0; i < inputs.length; ++i) {
+ getInputTensor(i).setTo(inputs[i]);
+ }
+
+ long inferenceStartNanos = System.nanoTime();
+ run(interpreterHandle, errorHandle);
+ long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+
+ for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
+ getOutputTensor(output.getKey()).copyTo(output.getValue());
}
- return outputs;
+
+ // Only set if the entire operation succeeds.
+ this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
}
- private static native long[] run(
- long interpreterHandle,
- long errorHandle,
- Object[] sizes,
- int[] dtypes,
- int[] numsOfBytes,
- Object[] values,
- NativeInterpreterWrapper wrapper,
- boolean memoryAllocated);
+ private static native boolean run(long interpreterHandle, long errorHandle);
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
+ // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
+ inputTensors[idx] = null;
}
}
@@ -212,21 +205,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- static int numElements(int[] shape) {
- if (shape == null) {
- return 0;
- }
- int n = 1;
- for (int i = 0; i < shape.length; i++) {
- n *= shape[i];
- }
- return n;
- }
-
- static boolean isNonEmptyArray(Object o) {
- return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
- }
-
/** Returns the type of the data. */
static DataType dataTypeOf(Object o) {
if (o != null) {
@@ -242,8 +220,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return DataType.UINT8;
} else if (long.class.equals(c)) {
return DataType.INT64;
- } else if (ByteBuffer.class.isInstance(o)) {
- return DataType.BYTEBUFFER;
}
}
throw new IllegalArgumentException(
@@ -293,40 +269,55 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
/**
- * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+ * Gets the quantization zero point of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- int[] getInputDims(int index) {
- return getInputDims(interpreterHandle, index, -1);
+ int getOutputQuantizationZeroPoint(int index) {
+ return getOutputQuantizationZeroPoint(interpreterHandle, index);
}
/**
- * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
- * input.
+ * Gets the quantization scale of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
-
- /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
- String getOutputDataType(int index) {
- int type = getOutputDataType(interpreterHandle, index);
- return DataType.fromNumber(type).toStringName();
+ float getOutputQuantizationScale(int index) {
+ return getOutputQuantizationScale(interpreterHandle, index);
}
/**
- * Gets the quantization zero point of an output.
+ * Gets the input {@link Tensor} for the provided input index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the input index is invalid.
*/
- int getOutputQuantizationZeroPoint(int index) {
- return getOutputQuantizationZeroPoint(interpreterHandle, index);
+ Tensor getInputTensor(int index) {
+ if (index < 0 || index >= inputTensors.length) {
+ throw new IllegalArgumentException("Invalid input Tensor index: " + index);
+ }
+ Tensor inputTensor = inputTensors[index];
+ if (inputTensor == null) {
+ inputTensor =
+ inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ }
+ return inputTensor;
}
/**
- * Gets the quantization scale of an output.
+ * Gets the output {@link Tensor} for the provided output index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- float getOutputQuantizationScale(int index) {
- return getOutputQuantizationScale(interpreterHandle, index);
+ Tensor getOutputTensor(int index) {
+ if (index < 0 || index >= outputTensors.length) {
+ throw new IllegalArgumentException("Invalid output Tensor index: " + index);
+ }
+ Tensor outputTensor = outputTensors[index];
+ if (outputTensor == null) {
+ outputTensor =
+ outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ }
+ return outputTensor;
}
private static native int getOutputDataType(long interpreterHandle, int outputIdx);
@@ -343,18 +334,30 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private long modelHandle;
- private int inputSize;
-
private long inferenceDurationNanoseconds = -1;
private ByteBuffer modelByteBuffer;
+ // Lazily constructed maps of input and output names to input and output Tensor indexes.
private Map<String, Integer> inputsIndexes;
-
private Map<String, Integer> outputsIndexes;
+ // Lazily constructed and populated arrays of input and output Tensor wrappers.
+ private final Tensor[] inputTensors;
+ private final Tensor[] outputTensors;
+
private boolean isMemoryAllocated = false;
+ private static native long allocateTensors(long interpreterHandle, long errorHandle);
+
+ private static native long getInputTensor(long interpreterHandle, int inputIdx);
+
+ private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+
+ private static native int getInputCount(long interpreterHandle);
+
+ private static native int getOutputCount(long interpreterHandle);
+
private static native String[] getInputNames(long interpreterHandle);
private static native String[] getOutputNames(long interpreterHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index b2a3e04c55..2c74c82417 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -31,43 +31,122 @@ final class Tensor {
return new Tensor(nativeHandle);
}
+ /** Returns the {@link DataType} of elements stored in the Tensor. */
+ public DataType dataType() {
+ return dtype;
+ }
+
+ /** Returns the size, in bytes, of the tensor data. */
+ public int numBytes() {
+ return numBytes(nativeHandle);
+ }
+
+ /**
+ * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
+ * the Tensor, i.e., the sizes of each dimension.
+ *
+ * @return an array where the i-th element is the size of the i-th dimension of the tensor.
+ */
+ public int[] shape() {
+ return shapeCopy;
+ }
+
+ /**
+ * Copies the contents of the provided {@code src} object to the Tensor.
+ *
+ * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
+ * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size.
+ *
+ * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
+ * with the tensor (for example, mismatched data types or shapes).
+ */
+ void setTo(Object src) {
+ throwExceptionIfTypeIsIncompatible(src);
+ if (isByteBuffer(src)) {
+ ByteBuffer srcBuffer = (ByteBuffer) src;
+ // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller
+ // retains ownership of the source buffer until inference has completed.
+ if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
+ writeDirectBuffer(nativeHandle, srcBuffer);
+ } else {
+ buffer().put(srcBuffer);
+ }
+ return;
+ }
+ writeMultiDimensionalArray(nativeHandle, src);
+ }
+
/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
* mismatched data types or shapes).
- * @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
- * data in this tensor.
*/
- <T> T copyTo(T dst) {
+ Object copyTo(Object dst) {
+ throwExceptionIfTypeIsIncompatible(dst);
if (dst instanceof ByteBuffer) {
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
dstByteBuffer.put(buffer());
return dst;
}
- if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
+
+ /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
+ // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
+ int[] getInputShapeIfDifferent(Object input) {
+ // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
+ // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
+ if (isByteBuffer(input)) {
+ return null;
+ }
+ int[] inputShape = NativeInterpreterWrapper.shapeOf(input);
+ if (Arrays.equals(shapeCopy, inputShape)) {
+ return null;
+ }
+ return inputShape;
+ }
+
+ private void throwExceptionIfTypeIsIncompatible(Object o) {
+ if (isByteBuffer(o)) {
+ ByteBuffer oBuffer = (ByteBuffer) o;
+ if (oBuffer.capacity() != numBytes()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot convert between a TensorFlowLite buffer with %d bytes and a "
+ + "ByteBuffer with %d bytes.",
+ numBytes(), oBuffer.capacity()));
+ }
+ return;
+ }
+ DataType oType = NativeInterpreterWrapper.dataTypeOf(o);
+ if (oType != dtype) {
throw new IllegalArgumentException(
String.format(
- "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java "
- + "object of type %s (which is compatible with the TensorFlowLite type %s)",
- dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
+ + "object of type %s (which is compatible with the TensorFlowLite type %s).",
+ dtype, o.getClass().getName(), oType));
}
- int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
- if (!Arrays.equals(dstShape, shapeCopy)) {
+
+ int[] oShape = NativeInterpreterWrapper.shapeOf(o);
+ if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
- "Output error: Shape of output target %s does not match with the shape of the "
- + "Tensor %s.",
- Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
+ + "with shape %s.",
+ Arrays.toString(shapeCopy), Arrays.toString(oShape)));
}
- readMultiDimensionalArray(nativeHandle, dst);
- return dst;
}
- final long nativeHandle;
- final DataType dtype;
- final int[] shapeCopy;
+ private static boolean isByteBuffer(Object o) {
+ return o instanceof ByteBuffer;
+ }
+
+ private final long nativeHandle;
+ private final DataType dtype;
+ private final int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
@@ -81,11 +160,17 @@ final class Tensor {
private static native ByteBuffer buffer(long handle);
+ private static native void writeDirectBuffer(long handle, ByteBuffer src);
+
private static native int dtype(long handle);
private static native int[] shape(long handle);
- private static native void readMultiDimensionalArray(long handle, Object value);
+ private static native int numBytes(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object dst);
+
+ private static native void writeMultiDimensionalArray(long handle, Object src);
static {
TensorFlowLite.init();
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
index 4399ed2025..4b4e1c21d8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/BUILD
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0
cc_library(
name = "native_framework_only",
srcs = [
- "duration_utils_jni.cc",
"exception_jni.cc",
"nativeinterpreterwrapper_jni.cc",
"tensor_jni.cc",
diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
deleted file mode 100644
index 0e08a04370..0000000000
--- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <jni.h>
-#include <time.h>
-
-namespace tflite {
-
-// Gets the elapsed wall-clock timespec.
-timespec getCurrentTime() {
- timespec time;
- clock_gettime(CLOCK_MONOTONIC, &time);
- return time;
-}
-
-// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier
-// than 'start'.
-jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) {
- jlong result = stop->tv_sec - start->tv_sec;
- if (result < 0) return -1;
- result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec);
- if (result < 0) return -1;
- return result;
-}
-
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 31f7b58fbc..e2c1edd9af 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -16,9 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
namespace {
-const int kByteBufferValue = 999;
-const int kBufferSize = 256;
-
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
@@ -62,22 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
-
-TfLiteType resolveDataType(jint data_type) {
- switch (data_type) {
- case 1:
- return kTfLiteFloat32;
- case 2:
- return kTfLiteInt32;
- case 3:
- return kTfLiteUInt8;
- case 4:
- return kTfLiteInt64;
- default:
- return kTfLiteNoType;
- }
-}
int getDataType(TfLiteType data_type) {
switch (data_type) {
@@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) {
}
}
-TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- const int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values,
- jobjectArray sizes) {
- if (input_size != interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Expected num of inputs is %d but got %d",
- interpreter->inputs().size(), input_size);
- return kTfLiteError;
- }
- if (input_size != env->GetArrayLength(data_types) ||
- input_size != env->GetArrayLength(nums_of_bytes) ||
- input_size != env->GetArrayLength(values)) {
- throwException(env, kIllegalArgumentException,
- "Internal error: Arrays in arguments should be of the same "
- "length, but got %d sizes, %d data_types, %d nums_of_bytes, "
- "and %d values",
- input_size, env->GetArrayLength(data_types),
- env->GetArrayLength(nums_of_bytes),
- env->GetArrayLength(values));
- return kTfLiteError;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- int num_dims = static_cast<int>(env->GetArrayLength(dims));
- if (target->dims->size != num_dims) {
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input should have %d dimensions, but "
- "found %d dimensions",
- i, target->dims->size, num_dims);
- return kTfLiteError;
- }
- jint* ptr = env->GetIntArrayElements(dims, nullptr);
- for (int j = 1; j < num_dims; ++j) {
- if (target->dims->data[j] != ptr[j]) {
- std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
- std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
- printDims(expected_dims.get(), kBufferSize, target->dims->data,
- num_dims);
- printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input dimension should be [%s], but "
- "found [%s]",
- i, expected_dims.get(), obtained_dims.get());
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- return kTfLiteError;
- }
- }
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
// Checks whether there is any difference between dimensions of a tensor and a
// given dimensions. Returns true if there is difference, else false.
bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
@@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
return false;
}
-bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- if (interpreter->inputs().size() != input_size) {
- return false;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteTensor* target = interpreter->tensor(input_idx);
- if (areDimsDifferent(env, target, dims)) return false;
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return false;
- }
- return true;
-}
-
-TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteStatus status = interpreter->ResizeInputTensor(
- input_idx, convertJIntArrayToVector(env, dims));
- if (status != kTfLiteOk) {
- return status;
- }
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values) {
- jint* data_type = env->GetIntArrayElements(data_types, nullptr);
- jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jobject value = env->GetObjectArrayElement(values, i);
- bool is_byte_buffer = isByteBuffer(data_type[i]);
- if (is_byte_buffer) {
- writeByteBuffer(env, value, &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- } else {
- TfLiteType type = resolveDataType(data_type[i]);
- if (type != target->type) {
- throwException(env, kIllegalArgumentException,
- "Input error: DataType (%d) of input data does not "
- "match with the DataType (%d) of model inputs.",
- type, target->type);
- return kTfLiteError;
- }
- writeMultiDimensionalArray(env, value, target->type, target->dims->size,
- &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- }
- env->DeleteLocalRef(value);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
- env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
- return kTfLiteOk;
-}
-
// TODO(yichengfan): evaluate the benefit to use tflite verifier.
bool VerifyModel(const void* buf, size_t len) {
flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
@@ -287,6 +142,63 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
return names;
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Internal error: Cannot allocate memory for the interpreter:"
+ " %s",
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->inputs()[index]));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->outputs()[index]));
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->inputs().size());
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->outputs().size());
+}
+
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
jclass clazz,
@@ -434,114 +346,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
}
// Sets inputs, runs inference, and returns outputs as long handles.
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated) {
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
- if (interpreter == nullptr) return nullptr;
+ if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
- if (error_reporter == nullptr) return nullptr;
- const int input_size = env->GetArrayLength(sizes);
- // validates inputs
- TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
- nums_of_bytes, values, sizes);
- if (status != kTfLiteOk) return nullptr;
- if (!memory_allocated ||
- !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) {
- // resizes inputs
- status = resizeInputs(env, interpreter, input_size, sizes);
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not resize the input: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- // allocates memory
- status = interpreter->AllocateTensors();
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not allocate memory for the given "
- "inputs: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- }
- // sets inputs
- status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
- values);
- if (status != kTfLiteOk) return nullptr;
- timespec beforeInference = ::tflite::getCurrentTime();
- // runs inference
+ if (error_reporter == nullptr) return;
+
if (interpreter->Invoke() != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Internal error: Failed to run on the given Interpreter: %s",
error_reporter->CachedErrorMessage());
- return nullptr;
- }
- timespec afterInference = ::tflite::getCurrentTime();
- jclass wrapper_clazz = env->GetObjectClass(wrapper);
- jfieldID fid =
- env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J");
- if (env->ExceptionCheck()) {
- env->ExceptionClear();
- } else if (fid != nullptr) {
- env->SetLongField(
- wrapper, fid,
- ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference));
- }
- // returns outputs
- const std::vector<int>& results = interpreter->outputs();
- if (results.empty()) {
- throwException(
- env, kIllegalArgumentException,
- "Internal error: The Interpreter does not have any outputs.");
- return nullptr;
- }
- jlongArray outputs = env->NewLongArray(results.size());
- size_t size = results.size();
- for (int i = 0; i < size; ++i) {
- TfLiteTensor* source = interpreter->tensor(results[i]);
- jlong output = reinterpret_cast<jlong>(source);
- env->SetLongArrayRegion(outputs, i, 1, &output);
- }
- return outputs;
-}
-
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
- if (interpreter == nullptr) return nullptr;
- const int idx = static_cast<int>(input_idx);
- if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Out of range: Failed to get %d-th input out of"
- " %d inputs",
- input_idx, interpreter->inputs().size());
- return nullptr;
- }
- TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
- int size = target->dims->size;
- if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid.
- int expected_num_bytes = elementByteSize(target->type);
- for (int i = 0; i < size; ++i) {
- expected_num_bytes *= target->dims->data[i];
- }
- if (num_bytes != expected_num_bytes) {
- throwException(env, kIllegalArgumentException,
- "Input error: Failed to get input dimensions. %d-th input "
- "should have %d bytes, but found %d bytes.",
- idx, expected_num_bytes, num_bytes);
- return nullptr;
- }
+ return;
}
- jintArray outputs = env->NewIntArray(size);
- env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
- return outputs;
}
JNIEXPORT jint JNICALL
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 128ece4981..618fba480e 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -29,9 +29,6 @@ limitations under the License.
namespace tflite {
// This is to be provided at link-time by a library.
extern std::unique_ptr<OpResolver> CreateOpResolver();
-extern timespec getCurrentTime();
-extern jlong timespec_diff_nanoseconds(struct timespec* start,
- struct timespec* stop);
} // namespace tflite
#ifdef __cplusplus
@@ -40,6 +37,57 @@ extern "C" {
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: allocateTensors
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
* Signature: (J)[Ljava/lang/Object;
*/
@@ -118,28 +166,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature:
- * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J
- */
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated);
-
-/*
- * Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature: (JII)[I
- *
- * Gets input dimensions. If num_bytes is non-negative, it will check whether
- * num_bytes matches num of bytes required by the input, and return null and
- * throw IllegalArgumentException if not.
+ * Method: run
+ * Signature: (JJ)V
*/
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 08b4d04280..7ff96a3172 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
return reinterpret_cast<TfLiteTensor*>(handle);
}
+size_t elementByteSize(TfLiteType data_type) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+ static_assert(sizeof(jfloat) == 4,
+ "Interal error: Java float not compatible with "
+ "kTfLiteFloat");
+ return 4;
+ case kTfLiteInt32:
+ static_assert(sizeof(jint) == 4,
+ "Interal error: Java int not compatible with kTfLiteInt");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Interal error: Java byte not compatible with "
+ "kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Interal error: Java long not compatible with "
+ "kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
+
size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
void* dst, size_t dst_size) {
jarray array = static_cast<jarray>(object);
@@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
}
}
-} // namespace
-
-size_t elementByteSize(TfLiteType data_type) {
- // The code in this file makes the assumption that the
- // TensorFlow TF_DataTypes and the Java primitive types
- // have the same byte sizes. Validate that:
- switch (data_type) {
- case kTfLiteFloat32:
- static_assert(sizeof(jfloat) == 4,
- "Interal error: Java float not compatible with "
- "kTfLiteFloat");
- return 4;
- case kTfLiteInt32:
- static_assert(sizeof(jint) == 4,
- "Interal error: Java int not compatible with kTfLiteInt");
- return 4;
- case kTfLiteUInt8:
- static_assert(sizeof(jbyte) == 1,
- "Interal error: Java byte not compatible with "
- "kTfLiteUInt8");
- return 1;
- case kTfLiteInt64:
- static_assert(sizeof(jlong) == 8,
- "Interal error: Java long not compatible with "
- "kTfLiteInt64");
- return 8;
- default:
- return 0;
- }
-}
-
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
- char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
- if (!buf) {
- throwException(env, kIllegalArgumentException,
- "Input ByteBuffer is not a direct buffer");
- return 0;
- }
- *dst = buf;
- return dst_size;
-}
-
size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
int dims_left, char** dst, int dst_size) {
if (dims_left <= 1) {
@@ -203,16 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
+} // namespace
+
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
TfLiteTensor* tensor = convertLongToTensor(env, handle);
if (tensor == nullptr) return nullptr;
-
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Tensor hasn't been allocated.");
+ return nullptr;
+ }
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
static_cast<jlong>(tensor->bytes));
}
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+
+ char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+ if (!src_data_raw) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return;
+ }
+
+ tensor->data.raw = src_data_raw;
+}
+
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
@@ -230,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
num_dims, static_cast<jarray>(value));
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Target Tensor hasn't been allocated.");
+ return;
+ }
+ if (tensor->dims->size == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Cannot copy empty/scalar Tensors.");
+ return;
+ }
+ writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size,
+ &tensor->data.raw, tensor->bytes);
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
@@ -247,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data);
return result;
}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->bytes);
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 9ba95d9ac4..06e2546af8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -34,6 +34,14 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: writeDirectBuffer
+ * Signature: (JLjava/nio/ByteBuffer;)
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: dtype
* Signature: (J)I
*/
@@ -52,6 +60,15 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: numBytes
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: readMultiDimensionalArray
* Signature: (JLjava/lang/Object;)
*/
@@ -59,23 +76,18 @@ JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
- jobject value);
+ jobject dst);
/*
- * Finds the size of each data type.
- */
-size_t elementByteSize(TfLiteType data_type);
-
-/*
- * Writes data of a ByteBuffer into dest.
- */
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
-
-/*
- * Writes a multi-dimensional array into dest.
+ * Class: org_tensorflow_lite_Tensor
+ * Method: writeMultidimensionalArray
+ * Signature: (JLjava/lang/Object;)
*/
-size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
- int dims_left, char** dst, int dst_size);
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 42096ef9a3..d66a73db94 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -221,7 +221,9 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
}
interpreter.close();
}
@@ -241,8 +243,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
interpreter.close();
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 029e5853e2..46bdecf443 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -20,6 +20,8 @@ import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -101,10 +103,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -119,11 +121,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs).hasLength(1);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
- outputs[0].copyTo(parsedOutput);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutput);
+ wrapper.run(inputs, outputs);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
@@ -140,17 +142,16 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
outputOneD = parsedOutputs[0][0][0];
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
wrapper.close();
@@ -164,10 +165,10 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
int[][][][] parsedOutputs = new int[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
int[] outputOneD = parsedOutputs[0][0][0];
int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
assertThat(outputOneD).isEqualTo(expected);
@@ -182,10 +183,10 @@ public final class NativeInterpreterWrapperTest {
long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
long[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
long[][][][] parsedOutputs = new long[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
@@ -203,10 +204,10 @@ public final class NativeInterpreterWrapperTest {
Object[] inputs = {fourD};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
@@ -229,13 +230,14 @@ public final class NativeInterpreterWrapperTest {
}
}
}
+ bbuf.rewind();
Object[] inputs = {bbuf};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
@@ -261,21 +263,22 @@ public final class NativeInterpreterWrapperTest {
}
}
Object[] inputs = {bbuf};
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ "Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
+ + "ByteBuffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
- float[][][][] parsedOutputs = new float[4][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -288,14 +291,18 @@ public final class NativeInterpreterWrapperTest {
ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
bbuf.order(ByteOrder.nativeOrder());
Object[] inputs = {bbuf};
+ Map<Integer, Object> outputs = new HashMap<>();
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ outputs.put(0, parsedOutput);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ "Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
+ + "ByteBuffer with 336 bytes.");
}
wrapper.close();
}
@@ -308,14 +315,18 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
wrapper.close();
}
@@ -329,8 +340,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
@@ -342,7 +356,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
try {
Object[] inputs = {};
- wrapper.run(inputs);
+ wrapper.run(inputs, null);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Inputs should not be null or empty.");
@@ -358,11 +372,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD, fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1");
}
wrapper.close();
}
@@ -374,13 +391,18 @@ public final class NativeInterpreterWrapperTest {
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Object[] inputs = {threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@@ -393,38 +415,23 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@Test
- public void testNumElements() {
- int[] shape = {2, 3, 4};
- int num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(24);
- shape = null;
- num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(0);
- }
-
- @Test
- public void testIsNonEmtpyArray() {
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
- int[] emptyArray = {};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
- int[] validArray = {9, 5, 2, 1};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
- }
-
- @Test
public void testDataTypeOf() {
float[] testEmtpyArray = {};
DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
@@ -486,8 +493,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
wrapper.close();
}
@@ -507,13 +516,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e)
- .hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ // Expected.
}
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
@@ -523,41 +533,7 @@ public final class NativeInterpreterWrapperTest {
public void testGetInputDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
int[] expectedDims = {1, 8, 8, 3};
- assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
- wrapper.close();
- }
-
- @Test
- public void testGetInputDimsOutOfRange() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- try {
- wrapper.getInputDims(-1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- try {
- wrapper.getInputDims(1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- wrapper.close();
- }
-
- @Test
- public void testGetOutputDataType() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("float");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("long");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("int");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("byte");
+ assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
wrapper.close();
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index dd9d37eeda..fe5926f6de 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -18,9 +18,10 @@ package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
-import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -35,7 +36,7 @@ public final class TensorTest {
"tensorflow/contrib/lite/java/src/testdata/add.bin";
private NativeInterpreterWrapper wrapper;
- private long nativeHandle;
+ private Tensor tensor;
@Before
public void setUp() {
@@ -45,8 +46,10 @@ public final class TensorTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- nativeHandle = outputs[0].nativeHandle;
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, new float[2][8][8][3]);
+ wrapper.run(inputs, outputs);
+ tensor = wrapper.getOutputTensor(0);
}
@After
@@ -55,17 +58,16 @@ public final class TensorTest {
}
@Test
- public void testFromHandle() throws Exception {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
+ public void testBasic() throws Exception {
assertThat(tensor).isNotNull();
int[] expectedShape = {2, 8, 8, 3};
- assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
- assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.shape()).isEqualTo(expectedShape);
+ assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
}
@Test
public void testCopyTo() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[2][8][8][3];
tensor.copyTo(parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
@@ -75,7 +77,6 @@ public final class TensorTest {
@Test
public void testCopyToByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
@@ -89,19 +90,17 @@ public final class TensorTest {
@Test
public void testCopyToInvalidByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
try {
tensor.copyTo(parsedOutput);
fail();
- } catch (BufferOverflowException e) {
+ } catch (IllegalArgumentException e) {
// Expected.
}
}
@Test
public void testCopyToWrongType() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
int[][][][] parsedOutputs = new int[2][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -110,15 +109,13 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
- + "type INT32)");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
}
@Test
public void testCopyToWrongShape() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[1][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -127,8 +124,50 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Shape of output target [1, 8, 8, 3] does not match "
- + "with the shape of the Tensor [2, 8, 8, 3].");
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ + "and a Java object with shape [1, 8, 8, 3].");
}
}
+
+ @Test
+ public void testSetTo() {
+ float[][][][] input = new float[2][8][8][3];
+ float[][][][] output = new float[2][8][8][3];
+ ByteBuffer inputByteBuffer =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+
+ input[0][0][0][0] = 2.0f;
+ tensor.setTo(input);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(2.0f);
+
+ inputByteBuffer.putFloat(0, 3.0f);
+ tensor.setTo(inputByteBuffer);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(3.0f);
+ }
+
+ @Test
+ public void testSetToInvalidByteBuffer() {
+ ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ try {
+ tensor.setTo(input);
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Success.
+ }
+ }
+
+ @Test
+ public void testGetInputShapeIfDifferent() {
+ ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull();
+
+ float[][][][] sameShapeInput = new float[2][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull();
+
+ float[][][][] differentShapeInput = new float[1][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
+ .isEqualTo(new int[] {1, 8, 8, 3});
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 3aef0c3bb6..c23521c077 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -58,7 +58,7 @@ public class TestHelper {
*/
public static int[] getInputDims(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getInputDims(index);
+ return interpreter.wrapper.getInputTensor(index).shape();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get input dimensions.");
@@ -77,7 +77,7 @@ public class TestHelper {
*/
public static String getOutputDataType(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getOutputDataType(index);
+ return interpreter.wrapper.getOutputTensor(index).dataType().toStringName();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get output data type.");