diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java')
-rw-r--r-- | tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java | 152 |
1 files changed, 137 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 94b6632bb8..71ef044943 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -18,6 +18,10 @@ package org.tensorflow.lite; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -32,7 +36,7 @@ public final class TensorTest { "tensorflow/contrib/lite/java/src/testdata/add.bin"; private NativeInterpreterWrapper wrapper; - private long nativeHandle; + private Tensor tensor; @Before public void setUp() { @@ -42,8 +46,10 @@ public final class TensorTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - nativeHandle = outputs[0].nativeHandle; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, new float[2][8][8][3]); + wrapper.run(inputs, outputs); + tensor = wrapper.getOutputTensor(0); } @After @@ -52,17 +58,16 @@ public final class TensorTest { } @Test - public void testFromHandle() throws Exception { - Tensor tensor = Tensor.fromHandle(nativeHandle); + public void testBasic() throws Exception { assertThat(tensor).isNotNull(); int[] expectedShape = {2, 8, 8, 3}; - assertThat(tensor.shapeCopy).isEqualTo(expectedShape); - assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + assertThat(tensor.shape()).isEqualTo(expectedShape); + assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); + assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); } @Test public void testCopyTo() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[2][8][8][3]; tensor.copyTo(parsedOutputs); float[] outputOneD = parsedOutputs[0][0][0]; @@ -71,8 +76,31 @@ public final class TensorTest { } @Test + public void testCopyToByteBuffer() { + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + tensor.copyTo(parsedOutput); + assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToInvalidByteBuffer() { + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.copyTo(parsedOutput); + fail(); + } catch (IllegalArgumentException e) { + // Expected. + } + } + + @Test public void testCopyToWrongType() { - Tensor tensor = Tensor.fromHandle(nativeHandle); int[][][][] parsedOutputs = new int[2][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -81,15 +109,13 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " - + "type INT32)"); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } } @Test public void testCopyToWrongShape() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[1][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -98,8 +124,104 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Shape of output target [1, 8, 8, 3] does not match " - + "with the shape of the Tensor [2, 8, 8, 3]."); + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] " + + "and a Java object with shape [1, 8, 8, 3]."); + } + } + + @Test + public void testSetTo() { + float[][][][] input = new float[2][8][8][3]; + float[][][][] output = new float[2][8][8][3]; + ByteBuffer inputByteBuffer = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + + input[0][0][0][0] = 2.0f; + tensor.setTo(input); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(2.0f); + + inputByteBuffer.putFloat(0, 3.0f); + tensor.setTo(inputByteBuffer); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(3.0f); + } + + @Test + public void testSetToInvalidByteBuffer() { + ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.setTo(input); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + + @Test + public void testGetInputShapeIfDifferent() { + ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull(); + + float[][][][] sameShapeInput = new float[2][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull(); + + float[][][][] differentShapeInput = new float[1][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) + .isEqualTo(new int[] {1, 8, 8, 3}); + } + + @Test + public void testDataTypeOf() { + float[] testEmptyArray = {}; + DataType dataType = Tensor.dataTypeOf(testEmptyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + Tensor.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + Tensor.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(Tensor.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(Tensor.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + Tensor.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = Tensor.numDimensions(array); + int[] shape = new int[num]; + Tensor.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); } } |