diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-21 17:03:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-21 17:06:47 -0700 |
commit | 006e293c7bc8a30ea8f9618cd305bc8719a96638 (patch) | |
tree | 8ff96036e378a59c5e01ef8852a1ec49714dd9a2 /tensorflow/contrib/lite/java/src/test | |
parent | 0ad7f20ed6876809a2b804365293a5c21dbcd374 (diff) |
Supports initializing an Interpreter with a direct ByteBuffer of nativeOrder()
that contains bytes content of a valid tflite model.
PiperOrigin-RevId: 197485253
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test')
-rw-r--r-- | tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 210d943724..82007a6ab5 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -19,6 +19,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import java.io.File; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -70,6 +72,49 @@ public final class InterpreterTest { } @Test + public void testRunWithDirectByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + Interpreter interpreter = new Interpreter(byteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRunWithInvalidByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + ByteBuffer byteBuffer = ByteBuffer.allocate((int) fileChannel.size()); + byteBuffer.order(ByteOrder.nativeOrder()); + fileChannel.read(byteBuffer); + try { + Interpreter interpreter = new Interpreter(byteBuffer); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Model ByteBuffer should be either a MappedByteBuffer" + + " of the model file, or a direct ByteBuffer using ByteOrder.nativeOrder()"); + } + fileChannel.close(); + } + + @Test public void testRun() { Interpreter interpreter = new Interpreter(MODEL_FILE); Float[] oneD = {1.23f, 6.54f, 7.81f}; |