diff options
author | Jared Duke <jdduke@google.com> | 2018-07-09 14:11:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 14:14:53 -0700 |
commit | 09924e82d576011dd9509e952a756dac1e3b9c60 (patch) | |
tree | 55700ca0310e4709dd5cd8f425825fc648cd452f /tensorflow/contrib/lite/java | |
parent | 16ffd7c6cf05f0817d584acca90ea195b19b0530 (diff) |
Implement Interpreter.run() in terms of Tensor APIs
PiperOrigin-RevId: 203826817
Diffstat (limited to 'tensorflow/contrib/lite/java')
15 files changed, 568 insertions, 627 deletions
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java index 56f3e7604a..1587c3c56f 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -127,12 +127,8 @@ public final class OvicClassifierTest { try { testResult = classifier.classifyByteBuffer(testImage); fail(); - } catch (RuntimeException e) { - assertThat(e) - .hasMessageThat() - .contains( - "Failed to get input dimensions. 0-th input should have 49152 bytes, " - + "but found 150528 bytes."); + } catch (IllegalArgumentException e) { + // Success. } } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java index 75334cd96e..94a1ec65d6 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -27,10 +27,7 @@ enum DataType { UINT8(3), /** 64-bit signed integer. */ - INT64(4), - - /** A {@link ByteBuffer}. */ - BYTEBUFFER(999); + INT64(4); private final int value; @@ -69,8 +66,6 @@ enum DataType { return 1; case INT64: return 8; - case BYTEBUFFER: - return 1; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); @@ -87,8 +82,6 @@ enum DataType { return "byte"; case INT64: return "long"; - case BYTEBUFFER: - return "ByteBuffer"; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 589fd6426f..7002f82677 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -165,20 +165,7 @@ public final class Interpreter implements AutoCloseable { if (wrapper == null) { throw new IllegalStateException("Internal error: The Interpreter has already been closed."); } - Tensor[] tensors = wrapper.run(inputs); - if (outputs == null || tensors == null || outputs.size() > tensors.length) { - throw new IllegalArgumentException("Output error: Outputs do not match with model outputs."); - } - final int size = tensors.length; - for (Integer idx : outputs.keySet()) { - if (idx == null || idx < 0 || idx >= size) { - throw new IllegalArgumentException( - String.format( - "Output error: Invalid index of output %d (should be in range [0, %d))", - idx, size)); - } - tensors[idx].copyTo(outputs.get(idx)); - } + wrapper.run(inputs, outputs); } /** diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 80de88b6a1..072cb26bb2 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -19,6 +19,7 @@ import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -40,6 +41,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModel(modelPath, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** @@ -72,6 +75,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -85,75 +90,63 @@ final class NativeInterpreterWrapper implements AutoCloseable { inputsIndexes = null; outputsIndexes = null; isMemoryAllocated = false; + Arrays.fill(inputTensors, null); + Arrays.fill(outputTensors, null); } /** Sets inputs, runs model inference and returns outputs. */ - Tensor[] run(Object[] inputs) { + void run(Object[] inputs, Map<Integer, Object> outputs) { + inferenceDurationNanoseconds = -1; if (inputs == null || inputs.length == 0) { throw new IllegalArgumentException("Input error: Inputs should not be null or empty."); } - int[] dataTypes = new int[inputs.length]; - Object[] sizes = new Object[inputs.length]; - int[] numsOfBytes = new int[inputs.length]; + if (outputs == null || outputs.isEmpty()) { + throw new IllegalArgumentException("Input error: Outputs should not be null or empty."); + } + + // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs. + // Rather than forcing an immediate resize + allocation if an input's shape differs, we first + // flush all resizes, avoiding redundant allocations. for (int i = 0; i < inputs.length; ++i) { - DataType dataType = dataTypeOf(inputs[i]); - dataTypes[i] = dataType.getNumber(); - if (dataType == DataType.BYTEBUFFER) { - ByteBuffer buffer = (ByteBuffer) inputs[i]; - if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input error: ByteBuffer should be a direct ByteBuffer that uses " - + "ByteOrder.nativeOrder()."); - } - numsOfBytes[i] = buffer.limit(); - sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); - } else if (isNonEmptyArray(inputs[i])) { - int[] dims = shapeOf(inputs[i]); - sizes[i] = dims; - numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); - } else { - throw new IllegalArgumentException( - String.format( - "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.", - i, inputs.length)); + Tensor tensor = getInputTensor(i); + int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]); + if (newShape != null) { + resizeInput(i, newShape); } } - inferenceDurationNanoseconds = -1; - long[] outputsHandles = - run( - interpreterHandle, - errorHandle, - sizes, - dataTypes, - numsOfBytes, - inputs, - this, - isMemoryAllocated); - if (outputsHandles == null || outputsHandles.length == 0) { - throw new IllegalStateException("Internal error: Interpreter has no outputs."); + + if (!isMemoryAllocated) { + allocateTensors(interpreterHandle, errorHandle); + isMemoryAllocated = true; + // Allocation can trigger dynamic resizing of output tensors, so clear the + // output tensor cache. + Arrays.fill(outputTensors, null); } - isMemoryAllocated = true; - Tensor[] outputs = new Tensor[outputsHandles.length]; - for (int i = 0; i < outputsHandles.length; ++i) { - outputs[i] = Tensor.fromHandle(outputsHandles[i]); + + for (int i = 0; i < inputs.length; ++i) { + getInputTensor(i).setTo(inputs[i]); + } + + long inferenceStartNanos = System.nanoTime(); + run(interpreterHandle, errorHandle); + long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos; + + for (Map.Entry<Integer, Object> output : outputs.entrySet()) { + getOutputTensor(output.getKey()).copyTo(output.getValue()); } - return outputs; + + // Only set if the entire operation succeeds. + this.inferenceDurationNanoseconds = inferenceDurationNanoseconds; } - private static native long[] run( - long interpreterHandle, - long errorHandle, - Object[] sizes, - int[] dtypes, - int[] numsOfBytes, - Object[] values, - NativeInterpreterWrapper wrapper, - boolean memoryAllocated); + private static native boolean run(long interpreterHandle, long errorHandle); /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { isMemoryAllocated = false; + // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle. + inputTensors[idx] = null; } } @@ -212,21 +205,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } - static int numElements(int[] shape) { - if (shape == null) { - return 0; - } - int n = 1; - for (int i = 0; i < shape.length; i++) { - n *= shape[i]; - } - return n; - } - - static boolean isNonEmptyArray(Object o) { - return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); - } - /** Returns the type of the data. */ static DataType dataTypeOf(Object o) { if (o != null) { @@ -242,8 +220,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { return DataType.UINT8; } else if (long.class.equals(c)) { return DataType.INT64; - } else if (ByteBuffer.class.isInstance(o)) { - return DataType.BYTEBUFFER; } } throw new IllegalArgumentException( @@ -293,40 +269,55 @@ final class NativeInterpreterWrapper implements AutoCloseable { } /** - * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid. + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - int[] getInputDims(int index) { - return getInputDims(interpreterHandle, index, -1); + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); } /** - * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the - * input. + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); - - /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */ - String getOutputDataType(int index) { - int type = getOutputDataType(interpreterHandle, index); - return DataType.fromNumber(type).toStringName(); + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); } /** - * Gets the quantization zero point of an output. + * Gets the input {@link Tensor} for the provided input index. * - * @throws IllegalArgumentExeption if the output index is invalid. + * @throws IllegalArgumentException if the input index is invalid. */ - int getOutputQuantizationZeroPoint(int index) { - return getOutputQuantizationZeroPoint(interpreterHandle, index); + Tensor getInputTensor(int index) { + if (index < 0 || index >= inputTensors.length) { + throw new IllegalArgumentException("Invalid input Tensor index: " + index); + } + Tensor inputTensor = inputTensors[index]; + if (inputTensor == null) { + inputTensor = + inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index)); + } + return inputTensor; } /** - * Gets the quantization scale of an output. + * Gets the output {@link Tensor} for the provided output index. * - * @throws IllegalArgumentExeption if the output index is invalid. + * @throws IllegalArgumentException if the output index is invalid. */ - float getOutputQuantizationScale(int index) { - return getOutputQuantizationScale(interpreterHandle, index); + Tensor getOutputTensor(int index) { + if (index < 0 || index >= outputTensors.length) { + throw new IllegalArgumentException("Invalid output Tensor index: " + index); + } + Tensor outputTensor = outputTensors[index]; + if (outputTensor == null) { + outputTensor = + outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index)); + } + return outputTensor; } private static native int getOutputDataType(long interpreterHandle, int outputIdx); @@ -343,18 +334,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { private long modelHandle; - private int inputSize; - private long inferenceDurationNanoseconds = -1; private ByteBuffer modelByteBuffer; + // Lazily constructed maps of input and output names to input and output Tensor indexes. private Map<String, Integer> inputsIndexes; - private Map<String, Integer> outputsIndexes; + // Lazily constructed and populated arrays of input and output Tensor wrappers. + private final Tensor[] inputTensors; + private final Tensor[] outputTensors; + private boolean isMemoryAllocated = false; + private static native long allocateTensors(long interpreterHandle, long errorHandle); + + private static native long getInputTensor(long interpreterHandle, int inputIdx); + + private static native long getOutputTensor(long interpreterHandle, int outputIdx); + + private static native int getInputCount(long interpreterHandle); + + private static native int getOutputCount(long interpreterHandle); + private static native String[] getInputNames(long interpreterHandle); private static native String[] getOutputNames(long interpreterHandle); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index b2a3e04c55..2c74c82417 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -31,43 +31,122 @@ final class Tensor { return new Tensor(nativeHandle); } + /** Returns the {@link DataType} of elements stored in the Tensor. */ + public DataType dataType() { + return dtype; + } + + /** Returns the size, in bytes, of the tensor data. */ + public int numBytes() { + return numBytes(nativeHandle); + } + + /** + * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of + * the Tensor, i.e., the sizes of each dimension. + * + * @return an array where the i-th element is the size of the i-th dimension of the tensor. + */ + public int[] shape() { + return shapeCopy; + } + + /** + * Copies the contents of the provided {@code src} object to the Tensor. + * + * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of + * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size. + * + * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible + * with the tensor (for example, mismatched data types or shapes). + */ + void setTo(Object src) { + throwExceptionIfTypeIsIncompatible(src); + if (isByteBuffer(src)) { + ByteBuffer srcBuffer = (ByteBuffer) src; + // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller + // retains ownership of the source buffer until inference has completed. + if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { + writeDirectBuffer(nativeHandle, srcBuffer); + } else { + buffer().put(srcBuffer); + } + return; + } + writeMultiDimensionalArray(nativeHandle, src); + } + /** * Copies the contents of the tensor to {@code dst} and returns {@code dst}. * * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}. * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, * mismatched data types or shapes). - * @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the - * data in this tensor. */ - <T> T copyTo(T dst) { + Object copyTo(Object dst) { + throwExceptionIfTypeIsIncompatible(dst); if (dst instanceof ByteBuffer) { ByteBuffer dstByteBuffer = (ByteBuffer) dst; dstByteBuffer.put(buffer()); return dst; } - if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ + // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. + int[] getInputShapeIfDifferent(Object input) { + // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. + // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. + if (isByteBuffer(input)) { + return null; + } + int[] inputShape = NativeInterpreterWrapper.shapeOf(input); + if (Arrays.equals(shapeCopy, inputShape)) { + return null; + } + return inputShape; + } + + private void throwExceptionIfTypeIsIncompatible(Object o) { + if (isByteBuffer(o)) { + ByteBuffer oBuffer = (ByteBuffer) o; + if (oBuffer.capacity() != numBytes()) { + throw new IllegalArgumentException( + String.format( + "Cannot convert between a TensorFlowLite buffer with %d bytes and a " + + "ByteBuffer with %d bytes.", + numBytes(), oBuffer.capacity())); + } + return; + } + DataType oType = NativeInterpreterWrapper.dataTypeOf(o); + if (oType != dtype) { throw new IllegalArgumentException( String.format( - "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java " - + "object of type %s (which is compatible with the TensorFlowLite type %s)", - dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + "Cannot convert between a TensorFlowLite tensor with type %s and a Java " + + "object of type %s (which is compatible with the TensorFlowLite type %s).", + dtype, o.getClass().getName(), oType)); } - int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); - if (!Arrays.equals(dstShape, shapeCopy)) { + + int[] oShape = NativeInterpreterWrapper.shapeOf(o); + if (!Arrays.equals(oShape, shapeCopy)) { throw new IllegalArgumentException( String.format( - "Output error: Shape of output target %s does not match with the shape of the " - + "Tensor %s.", - Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object " + + "with shape %s.", + Arrays.toString(shapeCopy), Arrays.toString(oShape))); } - readMultiDimensionalArray(nativeHandle, dst); - return dst; } - final long nativeHandle; - final DataType dtype; - final int[] shapeCopy; + private static boolean isByteBuffer(Object o) { + return o instanceof ByteBuffer; + } + + private final long nativeHandle; + private final DataType dtype; + private final int[] shapeCopy; private Tensor(long nativeHandle) { this.nativeHandle = nativeHandle; @@ -81,11 +160,17 @@ final class Tensor { private static native ByteBuffer buffer(long handle); + private static native void writeDirectBuffer(long handle, ByteBuffer src); + private static native int dtype(long handle); private static native int[] shape(long handle); - private static native void readMultiDimensionalArray(long handle, Object value); + private static native int numBytes(long handle); + + private static native void readMultiDimensionalArray(long handle, Object dst); + + private static native void writeMultiDimensionalArray(long handle, Object src); static { TensorFlowLite.init(); diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD index 4399ed2025..4b4e1c21d8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/BUILD +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0 cc_library( name = "native_framework_only", srcs = [ - "duration_utils_jni.cc", "exception_jni.cc", "nativeinterpreterwrapper_jni.cc", "tensor_jni.cc", diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc deleted file mode 100644 index 0e08a04370..0000000000 --- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <jni.h> -#include <time.h> - -namespace tflite { - -// Gets the elapsed wall-clock timespec. -timespec getCurrentTime() { - timespec time; - clock_gettime(CLOCK_MONOTONIC, &time); - return time; -} - -// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier -// than 'start'. -jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) { - jlong result = stop->tv_sec - start->tv_sec; - if (result < 0) return -1; - result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec); - if (result < 0) return -1; - return result; -} - -} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 31f7b58fbc..e2c1edd9af 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -16,9 +16,6 @@ limitations under the License. #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" namespace { -const int kByteBufferValue = 999; -const int kBufferSize = 256; - tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { if (handle == 0) { throwException(env, kIllegalArgumentException, @@ -62,22 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { return outputs; } -bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } - -TfLiteType resolveDataType(jint data_type) { - switch (data_type) { - case 1: - return kTfLiteFloat32; - case 2: - return kTfLiteInt32; - case 3: - return kTfLiteUInt8; - case 4: - return kTfLiteInt64; - default: - return kTfLiteNoType; - } -} int getDataType(TfLiteType data_type) { switch (data_type) { @@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) { } } -TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, - const int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values, - jobjectArray sizes) { - if (input_size != interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Expected num of inputs is %d but got %d", - interpreter->inputs().size(), input_size); - return kTfLiteError; - } - if (input_size != env->GetArrayLength(data_types) || - input_size != env->GetArrayLength(nums_of_bytes) || - input_size != env->GetArrayLength(values)) { - throwException(env, kIllegalArgumentException, - "Internal error: Arrays in arguments should be of the same " - "length, but got %d sizes, %d data_types, %d nums_of_bytes, " - "and %d values", - input_size, env->GetArrayLength(data_types), - env->GetArrayLength(nums_of_bytes), - env->GetArrayLength(values)); - return kTfLiteError; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jintArray dims = - static_cast<jintArray>(env->GetObjectArrayElement(sizes, i)); - int num_dims = static_cast<int>(env->GetArrayLength(dims)); - if (target->dims->size != num_dims) { - throwException(env, kIllegalArgumentException, - "Input error: %d-th input should have %d dimensions, but " - "found %d dimensions", - i, target->dims->size, num_dims); - return kTfLiteError; - } - jint* ptr = env->GetIntArrayElements(dims, nullptr); - for (int j = 1; j < num_dims; ++j) { - if (target->dims->data[j] != ptr[j]) { - std::unique_ptr<char[]> expected_dims(new char[kBufferSize]); - std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]); - printDims(expected_dims.get(), kBufferSize, target->dims->data, - num_dims); - printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); - throwException(env, kIllegalArgumentException, - "Input error: %d-th input dimension should be [%s], but " - "found [%s]", - i, expected_dims.get(), obtained_dims.get()); - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - return kTfLiteError; - } - } - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - // Checks whether there is any difference between dimensions of a tensor and a // given dimensions. Returns true if there is difference, else false. bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { @@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { return false; } -bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - if (interpreter->inputs().size() != input_size) { - return false; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast<jintArray>(env->GetObjectArrayElement(sizes, i)); - TfLiteTensor* target = interpreter->tensor(input_idx); - if (areDimsDifferent(env, target, dims)) return false; - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return false; - } - return true; -} - -TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast<jintArray>(env->GetObjectArrayElement(sizes, i)); - TfLiteStatus status = interpreter->ResizeInputTensor( - input_idx, convertJIntArrayToVector(env, dims)); - if (status != kTfLiteOk) { - return status; - } - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - -TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values) { - jint* data_type = env->GetIntArrayElements(data_types, nullptr); - jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jobject value = env->GetObjectArrayElement(values, i); - bool is_byte_buffer = isByteBuffer(data_type[i]); - if (is_byte_buffer) { - writeByteBuffer(env, value, &(target->data.raw), - static_cast<int>(num_bytes[i])); - } else { - TfLiteType type = resolveDataType(data_type[i]); - if (type != target->type) { - throwException(env, kIllegalArgumentException, - "Input error: DataType (%d) of input data does not " - "match with the DataType (%d) of model inputs.", - type, target->type); - return kTfLiteError; - } - writeMultiDimensionalArray(env, value, target->type, target->dims->size, - &(target->data.raw), - static_cast<int>(num_bytes[i])); - } - env->DeleteLocalRef(value); - if (env->ExceptionCheck()) return kTfLiteError; - } - env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); - env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); - return kTfLiteOk; -} - // TODO(yichengfan): evaluate the benefit to use tflite verifier. bool VerifyModel(const void* buf, size_t len) { flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len); @@ -287,6 +142,63 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, return names; } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + + if (interpreter->AllocateTensors() != kTfLiteOk) { + throwException(env, kNullPointerException, + "Internal error: Cannot allocate memory for the interpreter:" + " %s", + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast<jlong>( + interpreter->tensor(interpreter->inputs()[index])); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast<jlong>( + interpreter->tensor(interpreter->outputs()[index])); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast<jint>(interpreter->inputs().size()); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast<jint>(interpreter->outputs().size()); +} + JNIEXPORT jobjectArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, jclass clazz, @@ -434,114 +346,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( } // Sets inputs, runs inference, and returns outputs as long handles. -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated) { +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); - if (interpreter == nullptr) return nullptr; + if (interpreter == nullptr) return; BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); - if (error_reporter == nullptr) return nullptr; - const int input_size = env->GetArrayLength(sizes); - // validates inputs - TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, - nums_of_bytes, values, sizes); - if (status != kTfLiteOk) return nullptr; - if (!memory_allocated || - !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) { - // resizes inputs - status = resizeInputs(env, interpreter, input_size, sizes); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not resize the input: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - // allocates memory - status = interpreter->AllocateTensors(); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not allocate memory for the given " - "inputs: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - } - // sets inputs - status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, - values); - if (status != kTfLiteOk) return nullptr; - timespec beforeInference = ::tflite::getCurrentTime(); - // runs inference + if (error_reporter == nullptr) return; + if (interpreter->Invoke() != kTfLiteOk) { throwException(env, kIllegalArgumentException, "Internal error: Failed to run on the given Interpreter: %s", error_reporter->CachedErrorMessage()); - return nullptr; - } - timespec afterInference = ::tflite::getCurrentTime(); - jclass wrapper_clazz = env->GetObjectClass(wrapper); - jfieldID fid = - env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J"); - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } else if (fid != nullptr) { - env->SetLongField( - wrapper, fid, - ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference)); - } - // returns outputs - const std::vector<int>& results = interpreter->outputs(); - if (results.empty()) { - throwException( - env, kIllegalArgumentException, - "Internal error: The Interpreter does not have any outputs."); - return nullptr; - } - jlongArray outputs = env->NewLongArray(results.size()); - size_t size = results.size(); - for (int i = 0; i < size; ++i) { - TfLiteTensor* source = interpreter->tensor(results[i]); - jlong output = reinterpret_cast<jlong>(source); - env->SetLongArrayRegion(outputs, i, 1, &output); - } - return outputs; -} - -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { - tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return nullptr; - const int idx = static_cast<int>(input_idx); - if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Out of range: Failed to get %d-th input out of" - " %d inputs", - input_idx, interpreter->inputs().size()); - return nullptr; - } - TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); - int size = target->dims->size; - if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid. - int expected_num_bytes = elementByteSize(target->type); - for (int i = 0; i < size; ++i) { - expected_num_bytes *= target->dims->data[i]; - } - if (num_bytes != expected_num_bytes) { - throwException(env, kIllegalArgumentException, - "Input error: Failed to get input dimensions. %d-th input " - "should have %d bytes, but found %d bytes.", - idx, expected_num_bytes, num_bytes); - return nullptr; - } + return; } - jintArray outputs = env->NewIntArray(size); - env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); - return outputs; } JNIEXPORT jint JNICALL diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 128ece4981..618fba480e 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -29,9 +29,6 @@ limitations under the License. namespace tflite { // This is to be provided at link-time by a library. extern std::unique_ptr<OpResolver> CreateOpResolver(); -extern timespec getCurrentTime(); -extern jlong timespec_diff_nanoseconds(struct timespec* start, - struct timespec* stop); } // namespace tflite #ifdef __cplusplus @@ -40,6 +37,57 @@ extern "C" { /* * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: allocateTensors + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: * Signature: (J)[Ljava/lang/Object; */ @@ -118,28 +166,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: - * Signature: - * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J - */ -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated); - -/* - * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: - * Signature: (JII)[I - * - * Gets input dimensions. If num_bytes is non-negative, it will check whether - * num_bytes matches num of bytes required by the input, and return null and - * throw IllegalArgumentException if not. + * Method: run + * Signature: (JJ)V */ -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 08b4d04280..7ff96a3172 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { return reinterpret_cast<TfLiteTensor*>(handle); } +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Interal error: Java float not compatible with " + "kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Interal error: Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Interal error: Java byte not compatible with " + "kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Interal error: Java long not compatible with " + "kTfLiteInt64"); + return 8; + default: + return 0; + } +} + size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, void* dst, size_t dst_size) { jarray array = static_cast<jarray>(object); @@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, } } -} // namespace - -size_t elementByteSize(TfLiteType data_type) { - // The code in this file makes the assumption that the - // TensorFlow TF_DataTypes and the Java primitive types - // have the same byte sizes. Validate that: - switch (data_type) { - case kTfLiteFloat32: - static_assert(sizeof(jfloat) == 4, - "Interal error: Java float not compatible with " - "kTfLiteFloat"); - return 4; - case kTfLiteInt32: - static_assert(sizeof(jint) == 4, - "Interal error: Java int not compatible with kTfLiteInt"); - return 4; - case kTfLiteUInt8: - static_assert(sizeof(jbyte) == 1, - "Interal error: Java byte not compatible with " - "kTfLiteUInt8"); - return 1; - case kTfLiteInt64: - static_assert(sizeof(jlong) == 8, - "Interal error: Java long not compatible with " - "kTfLiteInt64"); - return 8; - default: - return 0; - } -} - -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { - char* buf = static_cast<char*>(env->GetDirectBufferAddress(object)); - if (!buf) { - throwException(env, kIllegalArgumentException, - "Input ByteBuffer is not a direct buffer"); - return 0; - } - *dst = buf; - return dst_size; -} - size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, int dims_left, char** dst, int dst_size) { if (dims_left <= 1) { @@ -203,16 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +} // namespace + JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, jclass clazz, jlong handle) { TfLiteTensor* tensor = convertLongToTensor(env, handle); if (tensor == nullptr) return nullptr; - + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Tensor hasn't been allocated."); + return nullptr; + } return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw), static_cast<jlong>(tensor->bytes)); } +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + + char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src)); + if (!src_data_raw) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return; + } + + tensor->data.raw = src_data_raw; +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, @@ -230,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, num_dims, static_cast<jarray>(value)); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Target Tensor hasn't been allocated."); + return; + } + if (tensor->dims->size == 0) { + throwException(env, kIllegalArgumentException, + "Internal error: Cannot copy empty/scalar Tensors."); + return; + } + writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size, + &tensor->data.raw, tensor->bytes); +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { @@ -247,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data); return result; } + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast<jint>(tensor->bytes); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index 9ba95d9ac4..06e2546af8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -34,6 +34,14 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, /* * Class: org_tensorflow_lite_Tensor + * Method: writeDirectBuffer + * Signature: (JLjava/nio/ByteBuffer;) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src); + +/* + * Class: org_tensorflow_lite_Tensor * Method: dtype * Signature: (J)I */ @@ -52,6 +60,15 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, /* * Class: org_tensorflow_lite_Tensor + * Method: numBytes + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor * Method: readMultiDimensionalArray * Signature: (JLjava/lang/Object;) */ @@ -59,23 +76,18 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, jlong handle, - jobject value); + jobject dst); /* - * Finds the size of each data type. - */ -size_t elementByteSize(TfLiteType data_type); - -/* - * Writes data of a ByteBuffer into dest. - */ -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); - -/* - * Writes a multi-dimensional array into dest. + * Class: org_tensorflow_lite_Tensor + * Method: writeMultidimensionalArray + * Signature: (JLjava/lang/Object;) */ -size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, - int dims_left, char** dst, int dst_size); +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 42096ef9a3..d66a73db94 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -221,7 +221,9 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); } interpreter.close(); } @@ -241,8 +243,8 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + " TensorFlowLite type INT32)"); } interpreter.close(); diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 029e5853e2..46bdecf443 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -20,6 +20,8 @@ import static org.junit.Assert.fail; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -101,10 +103,10 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); @@ -119,11 +121,11 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs).hasLength(1); ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); - outputs[0].copyTo(parsedOutput); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutput); + wrapper.run(inputs, outputs); float[] outputOneD = { parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) }; @@ -140,17 +142,16 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); - outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); outputOneD = parsedOutputs[0][0][0]; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); wrapper.close(); @@ -164,10 +165,10 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); int[][][][] parsedOutputs = new int[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); int[] outputOneD = parsedOutputs[0][0][0]; int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; assertThat(outputOneD).isEqualTo(expected); @@ -182,10 +183,10 @@ public final class NativeInterpreterWrapperTest { long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; long[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); long[][][][] parsedOutputs = new long[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); long[] outputOneD = parsedOutputs[0][0][0]; long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; @@ -203,10 +204,10 @@ public final class NativeInterpreterWrapperTest { Object[] inputs = {fourD}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; @@ -229,13 +230,14 @@ public final class NativeInterpreterWrapperTest { } } } + bbuf.rewind(); Object[] inputs = {bbuf}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = { (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, @@ -261,21 +263,22 @@ public final class NativeInterpreterWrapperTest { } } Object[] inputs = {bbuf}; + float[][][][] parsedOutputs = new float[4][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + "Cannot convert between a TensorFlowLite buffer with 768 bytes and a " + + "ByteBuffer with 3072 bytes."); } int[] inputDims = {4, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); - float[][][][] parsedOutputs = new float[4][8][8][3]; - outputs[0].copyTo(parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); @@ -288,14 +291,18 @@ public final class NativeInterpreterWrapperTest { ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); bbuf.order(ByteOrder.nativeOrder()); Object[] inputs = {bbuf}; + Map<Integer, Object> outputs = new HashMap<>(); + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + outputs.put(0, parsedOutput); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes."); + "Cannot convert between a TensorFlowLite buffer with 192 bytes and a " + + "ByteBuffer with 336 bytes."); } wrapper.close(); } @@ -308,14 +315,18 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } wrapper.close(); } @@ -329,8 +340,11 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); @@ -342,7 +356,7 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); try { Object[] inputs = {}; - wrapper.run(inputs); + wrapper.run(inputs, null); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Inputs should not be null or empty."); @@ -358,11 +372,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD, fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1"); } wrapper.close(); } @@ -374,13 +391,18 @@ public final class NativeInterpreterWrapperTest { float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; Object[] inputs = {threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } @@ -393,38 +415,23 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } @Test - public void testNumElements() { - int[] shape = {2, 3, 4}; - int num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(24); - shape = null; - num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(0); - } - - @Test - public void testIsNonEmtpyArray() { - assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); - assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); - int[] emptyArray = {}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); - int[] validArray = {9, 5, 2, 1}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); - } - - @Test public void testDataTypeOf() { float[] testEmtpyArray = {}; DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); @@ -486,8 +493,10 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L); wrapper.close(); } @@ -507,13 +516,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e) - .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + // Expected. } assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); wrapper.close(); @@ -523,41 +533,7 @@ public final class NativeInterpreterWrapperTest { public void testGetInputDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); int[] expectedDims = {1, 8, 8, 3}; - assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims); - wrapper.close(); - } - - @Test - public void testGetInputDimsOutOfRange() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - try { - wrapper.getInputDims(-1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); - } - try { - wrapper.getInputDims(1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); - } - wrapper.close(); - } - - @Test - public void testGetOutputDataType() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("float"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("long"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("int"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("byte"); + assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims); wrapper.close(); } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index dd9d37eeda..fe5926f6de 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -18,9 +18,10 @@ package org.tensorflow.lite; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; -import java.nio.BufferOverflowException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -35,7 +36,7 @@ public final class TensorTest { "tensorflow/contrib/lite/java/src/testdata/add.bin"; private NativeInterpreterWrapper wrapper; - private long nativeHandle; + private Tensor tensor; @Before public void setUp() { @@ -45,8 +46,10 @@ public final class TensorTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - nativeHandle = outputs[0].nativeHandle; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, new float[2][8][8][3]); + wrapper.run(inputs, outputs); + tensor = wrapper.getOutputTensor(0); } @After @@ -55,17 +58,16 @@ public final class TensorTest { } @Test - public void testFromHandle() throws Exception { - Tensor tensor = Tensor.fromHandle(nativeHandle); + public void testBasic() throws Exception { assertThat(tensor).isNotNull(); int[] expectedShape = {2, 8, 8, 3}; - assertThat(tensor.shapeCopy).isEqualTo(expectedShape); - assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + assertThat(tensor.shape()).isEqualTo(expectedShape); + assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); + assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); } @Test public void testCopyTo() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[2][8][8][3]; tensor.copyTo(parsedOutputs); float[] outputOneD = parsedOutputs[0][0][0]; @@ -75,7 +77,6 @@ public final class TensorTest { @Test public void testCopyToByteBuffer() { - Tensor tensor = Tensor.fromHandle(nativeHandle); ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); tensor.copyTo(parsedOutput); @@ -89,19 +90,17 @@ public final class TensorTest { @Test public void testCopyToInvalidByteBuffer() { - Tensor tensor = Tensor.fromHandle(nativeHandle); ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); try { tensor.copyTo(parsedOutput); fail(); - } catch (BufferOverflowException e) { + } catch (IllegalArgumentException e) { // Expected. } } @Test public void testCopyToWrongType() { - Tensor tensor = Tensor.fromHandle(nativeHandle); int[][][][] parsedOutputs = new int[2][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -110,15 +109,13 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " - + "type INT32)"); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } } @Test public void testCopyToWrongShape() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[1][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -127,8 +124,50 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Shape of output target [1, 8, 8, 3] does not match " - + "with the shape of the Tensor [2, 8, 8, 3]."); + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] " + + "and a Java object with shape [1, 8, 8, 3]."); } } + + @Test + public void testSetTo() { + float[][][][] input = new float[2][8][8][3]; + float[][][][] output = new float[2][8][8][3]; + ByteBuffer inputByteBuffer = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + + input[0][0][0][0] = 2.0f; + tensor.setTo(input); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(2.0f); + + inputByteBuffer.putFloat(0, 3.0f); + tensor.setTo(inputByteBuffer); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(3.0f); + } + + @Test + public void testSetToInvalidByteBuffer() { + ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.setTo(input); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + + @Test + public void testGetInputShapeIfDifferent() { + ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull(); + + float[][][][] sameShapeInput = new float[2][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull(); + + float[][][][] differentShapeInput = new float[1][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) + .isEqualTo(new int[] {1, 8, 8, 3}); + } } diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 3aef0c3bb6..c23521c077 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -58,7 +58,7 @@ public class TestHelper { */ public static int[] getInputDims(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getInputDims(index); + return interpreter.wrapper.getInputTensor(index).shape(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get input dimensions."); @@ -77,7 +77,7 @@ public class TestHelper { */ public static String getOutputDataType(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getOutputDataType(index); + return interpreter.wrapper.getOutputTensor(index).dataType().toStringName(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get output data type."); |