diff options
author | 2018-08-20 13:27:38 -0700 | |
---|---|---|
committer | 2018-08-20 13:32:43 -0700 | |
commit | e8894bdcda6c7fb899939406ff4f320d2c59b208 (patch) | |
tree | 9cbfa999ad57a27d085144a6efadfd9e3216e4b6 /tensorflow/contrib/lite/java/src/test | |
parent | 600caf99897e82cd0db8665acca5e7630ec1a292 (diff) |
Extend Java Interpreter API for TensorFlow Lite
Expose simple Tensor and DataType Java classes that can be used for
basic introspection. Note that this change does not allow direct
mutation of Tensor objects. The client must still use the
Interpreter.invoke() API for injecting and retrieving Tensor
data.
PiperOrigin-RevId: 209473412
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test')
3 files changed, 48 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java index cebc944200..6d6417f895 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -26,9 +26,16 @@ public final class DataTypeTest { @Test public void testElemByteSize() { - assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4); - assertThat(DataType.INT32.elemByteSize()).isEqualTo(4); - assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1); - assertThat(DataType.INT64.elemByteSize()).isEqualTo(8); + assertThat(DataType.FLOAT32.byteSize()).isEqualTo(4); + assertThat(DataType.INT32.byteSize()).isEqualTo(4); + assertThat(DataType.UINT8.byteSize()).isEqualTo(1); + assertThat(DataType.INT64.byteSize()).isEqualTo(8); + } + + @Test + public void testConversion() { + for (DataType dataType : DataType.values()) { + assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType); + } } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index d66a73db94..9070b788b6 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -47,6 +47,10 @@ public final class InterpreterTest { public void testInterpreter() throws Exception { Interpreter interpreter = new Interpreter(MODEL_FILE); assertThat(interpreter).isNotNull(); + assertThat(interpreter.getInputTensorCount()).isEqualTo(1); + assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); + assertThat(interpreter.getOutputTensorCount()).isEqualTo(1); + assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32); interpreter.close(); } @@ -183,6 +187,19 @@ public final class InterpreterTest { } @Test + public void testResizeInput() { + try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { + int[] inputDims = {1}; + interpreter.resizeInput(0, inputDims); + assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims); + ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + interpreter.run(input, output); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + } + } + + @Test public void testMobilenetRun() { // Create a gray image. float[][][][] img = new float[1][224][224][3]; @@ -199,6 +216,8 @@ public final class InterpreterTest { Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); interpreter.run(img, labels); + assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3}); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001}); interpreter.close(); assertThat(labels[0]) diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 71ef044943..85ad393d89 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -64,6 +64,8 @@ public final class TensorTest { assertThat(tensor.shape()).isEqualTo(expectedShape); assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); + assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3); + assertThat(tensor.numDimensions()).isEqualTo(4); } @Test @@ -201,12 +203,12 @@ public final class TensorTest { @Test public void testNumDimensions() { int scalar = 1; - assertThat(Tensor.numDimensions(scalar)).isEqualTo(0); + assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0); int[][] array = {{2, 4}, {1, 9}}; - assertThat(Tensor.numDimensions(array)).isEqualTo(2); + assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2); try { int[] emptyArray = {}; - Tensor.numDimensions(emptyArray); + Tensor.computeNumDimensions(emptyArray); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); @@ -214,9 +216,21 @@ public final class TensorTest { } @Test + public void testNumElements() { + int[] scalarShape = {}; + assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1); + int[] vectorShape = {3}; + assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3); + int[] matrixShape = {3, 4}; + assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12); + int[] degenerateShape = {3, 4, 0}; + assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0); + } + + @Test public void testFillShape() { int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; - int num = Tensor.numDimensions(array); + int num = Tensor.computeNumDimensions(array); int[] shape = new int[num]; Tensor.fillShape(array, 0, shape); assertThat(num).isEqualTo(3); |