aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-08-20 13:27:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 13:32:43 -0700
commite8894bdcda6c7fb899939406ff4f320d2c59b208 (patch)
tree9cbfa999ad57a27d085144a6efadfd9e3216e4b6 /tensorflow/contrib/lite/java
parent600caf99897e82cd0db8665acca5e7630ec1a292 (diff)
Extend Java Interpreter API for TensorFlow Lite
Expose simple Tensor and DataType Java classes that can be used for basic introspection. Note that this change does not allow direct mutation of Tensor objects. The client must still use the Interpreter.invoke() API for injecting and retrieving Tensor data. PiperOrigin-RevId: 209473412
Diffstat (limited to 'tensorflow/contrib/lite/java')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java44
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java70
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java29
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java51
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java15
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java19
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java22
7 files changed, 181 insertions, 69 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 94a1ec65d6..41093e8ffe 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -15,8 +15,8 @@ limitations under the License.
package org.tensorflow.lite;
-/** Type of elements in a {@link TfLiteTensor}. */
-enum DataType {
+/** Represents the type of elements in a TensorFlow Lite {@link Tensor} as an enum. */
+public enum DataType {
/** 32-bit single precision floating point. */
FLOAT32(1),
@@ -35,13 +35,29 @@ enum DataType {
this.value = value;
}
- /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
- int getNumber() {
+ /** Returns the size of an element of this type, in bytes, or -1 if element size is variable. */
+ public int byteSize() {
+ switch (this) {
+ case FLOAT32:
+ return 4;
+ case INT32:
+ return 4;
+ case UINT8:
+ return 1;
+ case INT64:
+ return 8;
+ }
+ throw new IllegalArgumentException(
+ "DataType error: DataType " + this + " is not supported yet");
+ }
+
+ /** Corresponding value of the TfLiteType enum in the TensorFlow Lite C API. */
+ int c() {
return value;
}
- /** Converts an integer to the corresponding type. */
- static DataType fromNumber(int c) {
+ /** Converts a C TfLiteType enum value to the corresponding type. */
+ static DataType fromC(int c) {
for (DataType t : values) {
if (t.value == c) {
return t;
@@ -55,22 +71,6 @@ enum DataType {
+ ")");
}
- /** Returns byte size of the type. */
- int elemByteSize() {
- switch (this) {
- case FLOAT32:
- return 4;
- case INT32:
- return 4;
- case UINT8:
- return 1;
- case INT64:
- return 8;
- }
- throw new IllegalArgumentException(
- "DataType error: DataType " + this + " is not supported yet");
- }
-
/** Gets string names of the data type. */
String toStringName() {
switch (this) {
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 7002f82677..b84720ae8e 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -162,9 +162,7 @@ public final class Interpreter implements AutoCloseable {
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.run(inputs, outputs);
}
@@ -174,12 +172,16 @@ public final class Interpreter implements AutoCloseable {
* <p>IllegalArgumentException will be thrown if it fails to resize.
*/
public void resizeInput(int idx, @NonNull int[] dims) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.resizeInput(idx, dims);
}
+ /** Gets the number of input tensors. */
+ public int getInputTensorCount() {
+ checkNotClosed();
+ return wrapper.getInputTensorCount();
+ }
+
/**
* Gets index of an input given the op name of the input.
*
@@ -187,51 +189,65 @@ public final class Interpreter implements AutoCloseable {
* to initialize the {@link Interpreter}.
*/
public int getInputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getInputIndex(opName);
}
/**
+ * Gets the Tensor associated with the provdied input index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getInputTensor(int inputIndex) {
+ checkNotClosed();
+ return wrapper.getInputTensor(inputIndex);
+ }
+
+ /** Gets the number of output Tensors. */
+ public int getOutputTensorCount() {
+ checkNotClosed();
+ return wrapper.getOutputTensorCount();
+ }
+
+ /**
* Gets index of an output given the op name of the output.
*
* <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
* to initialize the {@link Interpreter}.
*/
public int getOutputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getOutputIndex(opName);
}
/**
+ * Gets the Tensor associated with the provdied output index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getOutputTensor(int outputIndex) {
+ checkNotClosed();
+ return wrapper.getOutputTensor(outputIndex);
+ }
+
+ /**
* Returns native inference timing.
* <p>IllegalArgumentException will be thrown if the model is not initialized by the
* {@link Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
/** Turns on/off Android NNAPI for hardware acceleration when it is available. */
public void setUseNNAPI(boolean useNNAPI) {
- if (wrapper != null) {
- wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalStateException(
- "Internal error: NativeInterpreterWrapper has already been closed.");
- }
+ checkNotClosed();
+ wrapper.setUseNNAPI(useNNAPI);
}
public void setNumThreads(int numThreads) {
- if (wrapper == null) {
- throw new IllegalStateException("The interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.setNumThreads(numThreads);
}
@@ -253,5 +269,11 @@ public final class Interpreter implements AutoCloseable {
}
}
+ private void checkNotClosed() {
+ if (wrapper == null) {
+ throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
+ }
+ }
+
NativeInterpreterWrapper wrapper;
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 767a220f8c..fa25082304 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -114,12 +114,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- if (!isMemoryAllocated) {
+ boolean needsAllocation = !isMemoryAllocated;
+ if (needsAllocation) {
allocateTensors(interpreterHandle, errorHandle);
isMemoryAllocated = true;
- // Allocation can trigger dynamic resizing of output tensors, so clear the
- // output tensor cache.
- Arrays.fill(outputTensors, null);
}
for (int i = 0; i < inputs.length; ++i) {
@@ -130,6 +128,14 @@ final class NativeInterpreterWrapper implements AutoCloseable {
run(interpreterHandle, errorHandle);
long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+ // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
+ if (needsAllocation) {
+ for (int i = 0; i < outputTensors.length; ++i) {
+ if (outputTensors[i] != null) {
+ outputTensors[i].refreshShape();
+ }
+ }
+ }
for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
getOutputTensor(output.getKey()).copyTo(output.getValue());
}
@@ -144,8 +150,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
- // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
- inputTensors[idx] = null;
+ if (inputTensors[idx] != null) {
+ inputTensors[idx].refreshShape();
+ }
}
}
@@ -230,6 +237,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return getOutputQuantizationScale(interpreterHandle, index);
}
+ /** Gets the number of input tensors. */
+ int getInputTensorCount() {
+ return inputTensors.length;
+ }
+
/**
* Gets the input {@link Tensor} for the provided input index.
*
@@ -247,6 +259,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return inputTensor;
}
+ /** Gets the number of output tensors. */
+ int getOutputTensorCount() {
+ return inputTensors.length;
+ }
+
/**
* Gets the output {@link Tensor} for the provided output index.
*
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 2403570c52..f174178d98 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
* <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
* needed to be closed here.
*/
-final class Tensor {
+public final class Tensor {
static Tensor fromHandle(long nativeHandle) {
return new Tensor(nativeHandle);
@@ -37,11 +37,26 @@ final class Tensor {
return dtype;
}
+ /**
+ * Returns the number of dimensions (sometimes referred to as <a
+ * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
+ *
+ * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
+ */
+ public int numDimensions() {
+ return shapeCopy.length;
+ }
+
/** Returns the size, in bytes, of the tensor data. */
public int numBytes() {
return numBytes(nativeHandle);
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor. */
+ public int numElements() {
+ return computeNumElements(shapeCopy);
+ }
+
/**
* Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
* the Tensor, i.e., the sizes of each dimension.
@@ -103,13 +118,22 @@ final class Tensor {
if (isByteBuffer(input)) {
return null;
}
- int[] inputShape = shapeOf(input);
+ int[] inputShape = computeShapeOf(input);
if (Arrays.equals(shapeCopy, inputShape)) {
return null;
}
return inputShape;
}
+ /**
+ * Forces a refresh of the tensor's cached shape.
+ *
+ * <p>This is useful if the tensor is resized or has a dynamic shape.
+ */
+ void refreshShape() {
+ this.shapeCopy = shape(nativeHandle);
+ }
+
/** Returns the type of the data. */
static DataType dataTypeOf(Object o) {
if (o != null) {
@@ -132,22 +156,31 @@ final class Tensor {
}
/** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
+ static int[] computeShapeOf(Object o) {
+ int size = computeNumDimensions(o);
int[] dimensions = new int[size];
fillShape(o, 0, dimensions);
return dimensions;
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */
+ static int computeNumElements(int[] shape) {
+ int n = 1;
+ for (int i = 0; i < shape.length; ++i) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
/** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
- static int numDimensions(Object o) {
+ static int computeNumDimensions(Object o) {
if (o == null || !o.getClass().isArray()) {
return 0;
}
if (Array.getLength(o) == 0) {
throw new IllegalArgumentException("Array lengths cannot be 0.");
}
- return 1 + numDimensions(Array.get(o, 0));
+ return 1 + computeNumDimensions(Array.get(o, 0));
}
/** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
@@ -188,7 +221,7 @@ final class Tensor {
dtype, o.getClass().getName(), oType));
}
- int[] oShape = shapeOf(o);
+ int[] oShape = computeShapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
@@ -204,11 +237,11 @@ final class Tensor {
private final long nativeHandle;
private final DataType dtype;
- private final int[] shapeCopy;
+ private int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
- this.dtype = DataType.fromNumber(dtype(nativeHandle));
+ this.dtype = DataType.fromC(dtype(nativeHandle));
this.shapeCopy = shape(nativeHandle);
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
index cebc944200..6d6417f895 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -26,9 +26,16 @@ public final class DataTypeTest {
@Test
public void testElemByteSize() {
- assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
- assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ assertThat(DataType.FLOAT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.byteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.byteSize()).isEqualTo(8);
+ }
+
+ @Test
+ public void testConversion() {
+ for (DataType dataType : DataType.values()) {
+ assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
+ }
}
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index d66a73db94..9070b788b6 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -47,6 +47,10 @@ public final class InterpreterTest {
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
interpreter.close();
}
@@ -183,6 +187,19 @@ public final class InterpreterTest {
}
@Test
+ public void testResizeInput() {
+ try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
+ int[] inputDims = {1};
+ interpreter.resizeInput(0, inputDims);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
+ ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ interpreter.run(input, output);
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
+ }
+ }
+
+ @Test
public void testMobilenetRun() {
// Create a gray image.
float[][][][] img = new float[1][224][224][3];
@@ -199,6 +216,8 @@ public final class InterpreterTest {
Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
interpreter.run(img, labels);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
interpreter.close();
assertThat(labels[0])
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 71ef044943..85ad393d89 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -64,6 +64,8 @@ public final class TensorTest {
assertThat(tensor.shape()).isEqualTo(expectedShape);
assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
+ assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3);
+ assertThat(tensor.numDimensions()).isEqualTo(4);
}
@Test
@@ -201,12 +203,12 @@ public final class TensorTest {
@Test
public void testNumDimensions() {
int scalar = 1;
- assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0);
int[][] array = {{2, 4}, {1, 9}};
- assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2);
try {
int[] emptyArray = {};
- Tensor.numDimensions(emptyArray);
+ Tensor.computeNumDimensions(emptyArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
@@ -214,9 +216,21 @@ public final class TensorTest {
}
@Test
+ public void testNumElements() {
+ int[] scalarShape = {};
+ assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1);
+ int[] vectorShape = {3};
+ assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3);
+ int[] matrixShape = {3, 4};
+ assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12);
+ int[] degenerateShape = {3, 4, 0};
+ assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0);
+ }
+
+ @Test
public void testFillShape() {
int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = Tensor.numDimensions(array);
+ int num = Tensor.computeNumDimensions(array);
int[] shape = new int[num];
Tensor.fillShape(array, 0, shape);
assertThat(num).isEqualTo(3);