aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java')
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java251
1 files changed, 97 insertions, 154 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 9e41cb132d..9c4a5acd79 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -20,6 +20,8 @@ import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -101,10 +103,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -112,6 +114,27 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testRunWithBufferOutput() {
+ try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
+ float[] oneD = {1.23f, -6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ ByteBuffer parsedOutput =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutput);
+ wrapper.run(inputs, outputs);
+ float[] outputOneD = {
+ parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
+ };
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+ }
+
+ @Test
public void testRunWithInputsOfSameDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, -6.54f, 7.81f};
@@ -119,17 +142,16 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
outputOneD = parsedOutputs[0][0][0];
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
wrapper.close();
@@ -143,10 +165,10 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
int[][][][] parsedOutputs = new int[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
int[] outputOneD = parsedOutputs[0][0][0];
int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
assertThat(outputOneD).isEqualTo(expected);
@@ -161,10 +183,10 @@ public final class NativeInterpreterWrapperTest {
long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
long[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
long[][][][] parsedOutputs = new long[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
@@ -182,10 +204,10 @@ public final class NativeInterpreterWrapperTest {
Object[] inputs = {fourD};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
@@ -208,13 +230,14 @@ public final class NativeInterpreterWrapperTest {
}
}
}
+ bbuf.rewind();
Object[] inputs = {bbuf};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
@@ -240,21 +263,22 @@ public final class NativeInterpreterWrapperTest {
}
}
Object[] inputs = {bbuf};
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ "Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
+ + "ByteBuffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
- float[][][][] parsedOutputs = new float[4][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -267,14 +291,18 @@ public final class NativeInterpreterWrapperTest {
ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
bbuf.order(ByteOrder.nativeOrder());
Object[] inputs = {bbuf};
+ Map<Integer, Object> outputs = new HashMap<>();
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ outputs.put(0, parsedOutput);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ "Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
+ + "ByteBuffer with 336 bytes.");
}
wrapper.close();
}
@@ -287,14 +315,18 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
wrapper.close();
}
@@ -308,8 +340,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
@@ -321,7 +356,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
try {
Object[] inputs = {};
- wrapper.run(inputs);
+ wrapper.run(inputs, null);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Inputs should not be null or empty.");
@@ -337,11 +372,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD, fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1");
}
wrapper.close();
}
@@ -353,13 +391,18 @@ public final class NativeInterpreterWrapperTest {
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Object[] inputs = {threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@@ -372,92 +415,23 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@Test
- public void testNumElements() {
- int[] shape = {2, 3, 4};
- int num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(24);
- shape = null;
- num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(0);
- }
-
- @Test
- public void testIsNonEmtpyArray() {
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
- int[] emptyArray = {};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
- int[] validArray = {9, 5, 2, 1};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
- }
-
- @Test
- public void testDataTypeOf() {
- float[] testEmtpyArray = {};
- DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[] testFloatArray = {0.783f, 0.251f};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- try {
- double[] testDoubleArray = {0.783, 0.251};
- NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
- }
- try {
- Float[] testBoxedArray = {0.783f, 0.251f};
- NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
- }
- }
-
- @Test
- public void testNumDimensions() {
- int scalar = 1;
- assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
- int[][] array = {{2, 4}, {1, 9}};
- assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
- try {
- int[] emptyArray = {};
- NativeInterpreterWrapper.numDimensions(emptyArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
- }
- }
-
- @Test
- public void testFillShape() {
- int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = NativeInterpreterWrapper.numDimensions(array);
- int[] shape = new int[num];
- NativeInterpreterWrapper.fillShape(array, 0, shape);
- assertThat(num).isEqualTo(3);
- assertThat(shape[0]).isEqualTo(2);
- assertThat(shape[1]).isEqualTo(3);
- assertThat(shape[2]).isEqualTo(1);
- }
-
- @Test
public void testGetInferenceLatency() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -465,8 +439,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
wrapper.close();
}
@@ -486,13 +462,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e)
- .hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ // Expected.
}
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
@@ -502,41 +479,7 @@ public final class NativeInterpreterWrapperTest {
public void testGetInputDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
int[] expectedDims = {1, 8, 8, 3};
- assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
- wrapper.close();
- }
-
- @Test
- public void testGetInputDimsOutOfRange() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- try {
- wrapper.getInputDims(-1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- try {
- wrapper.getInputDims(1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- wrapper.close();
- }
-
- @Test
- public void testGetOutputDataType() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("float");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("long");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("int");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("byte");
+ assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
wrapper.close();
}