aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src/test
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-21 17:03:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 17:06:47 -0700
commit006e293c7bc8a30ea8f9618cd305bc8719a96638 (patch)
tree8ff96036e378a59c5e01ef8852a1ec49714dd9a2 /tensorflow/contrib/lite/java/src/test
parent0ad7f20ed6876809a2b804365293a5c21dbcd374 (diff)
Supports initializing an Interpreter with a direct ByteBuffer of nativeOrder()
that contains bytes content of a valid tflite model. PiperOrigin-RevId: 197485253
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test')
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 210d943724..82007a6ab5 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -19,6 +19,8 @@ import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import java.io.File;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
@@ -70,6 +72,49 @@ public final class InterpreterTest {
}
@Test
+ public void testRunWithDirectByteBufferModel() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size());
+ byteBuffer.order(ByteOrder.nativeOrder());
+ fileChannel.read(byteBuffer);
+ Interpreter interpreter = new Interpreter(byteBuffer);
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ interpreter.run(fourD, parsedOutputs);
+ float[] outputOneD = parsedOutputs[0][0][0];
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ interpreter.close();
+ fileChannel.close();
+ }
+
+ @Test
+ public void testRunWithInvalidByteBufferModel() throws Exception {
+ Path path = MODEL_FILE.toPath();
+ FileChannel fileChannel =
+ (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
+ ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size());
+ byteBuffer.order(ByteOrder.nativeOrder());
+ fileChannel.read(byteBuffer);
+ try {
+ Interpreter interpreter = new Interpreter(byteBuffer);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Model ByteBuffer should be either a MappedByteBuffer"
+ + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()");
+ }
+ fileChannel.close();
+ }
+
+ @Test
public void testRun() {
Interpreter interpreter = new Interpreter(MODEL_FILE);
Float[] oneD = {1.23f, 6.54f, 7.81f};