aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src/test
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-08-20 13:27:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 13:32:43 -0700
commite8894bdcda6c7fb899939406ff4f320d2c59b208 (patch)
tree9cbfa999ad57a27d085144a6efadfd9e3216e4b6 /tensorflow/contrib/lite/java/src/test
parent600caf99897e82cd0db8665acca5e7630ec1a292 (diff)
Extend Java Interpreter API for TensorFlow Lite
Expose simple Tensor and DataType Java classes that can be used for basic introspection. Note that this change does not allow direct mutation of Tensor objects. The client must still use the Interpreter.invoke() API for injecting and retrieving Tensor data. PiperOrigin-RevId: 209473412
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test')
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java15
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java19
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java22
3 files changed, 48 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
index cebc944200..6d6417f895 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -26,9 +26,16 @@ public final class DataTypeTest {
@Test
public void testElemByteSize() {
- assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
- assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ assertThat(DataType.FLOAT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.byteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.byteSize()).isEqualTo(8);
+ }
+
+ @Test
+ public void testConversion() {
+ for (DataType dataType : DataType.values()) {
+ assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
+ }
}
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index d66a73db94..9070b788b6 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -47,6 +47,10 @@ public final class InterpreterTest {
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
interpreter.close();
}
@@ -183,6 +187,19 @@ public final class InterpreterTest {
}
@Test
+ public void testResizeInput() {
+ try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
+ int[] inputDims = {1};
+ interpreter.resizeInput(0, inputDims);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
+ ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ interpreter.run(input, output);
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
+ }
+ }
+
+ @Test
public void testMobilenetRun() {
// Create a gray image.
float[][][][] img = new float[1][224][224][3];
@@ -199,6 +216,8 @@ public final class InterpreterTest {
Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
interpreter.run(img, labels);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
interpreter.close();
assertThat(labels[0])
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 71ef044943..85ad393d89 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -64,6 +64,8 @@ public final class TensorTest {
assertThat(tensor.shape()).isEqualTo(expectedShape);
assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
+ assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3);
+ assertThat(tensor.numDimensions()).isEqualTo(4);
}
@Test
@@ -201,12 +203,12 @@ public final class TensorTest {
@Test
public void testNumDimensions() {
int scalar = 1;
- assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0);
int[][] array = {{2, 4}, {1, 9}};
- assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2);
try {
int[] emptyArray = {};
- Tensor.numDimensions(emptyArray);
+ Tensor.computeNumDimensions(emptyArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
@@ -214,9 +216,21 @@ public final class TensorTest {
}
@Test
+ public void testNumElements() {
+ int[] scalarShape = {};
+ assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1);
+ int[] vectorShape = {3};
+ assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3);
+ int[] matrixShape = {3, 4};
+ assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12);
+ int[] degenerateShape = {3, 4, 0};
+ assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0);
+ }
+
+ @Test
public void testFillShape() {
int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = Tensor.numDimensions(array);
+ int num = Tensor.computeNumDimensions(array);
int[] shape = new int[num];
Tensor.fillShape(array, 0, shape);
assertThat(num).isEqualTo(3);