diff options
author | 2018-06-15 16:52:01 -0700 | |
---|---|---|
committer | 2018-06-15 16:55:40 -0700 | |
commit | d1daba6ac82461cd64dc070534bc613a70527520 (patch) | |
tree | 8891eb4f10f4d6864952a45eb2e39226f9c8ab72 /tensorflow/contrib/lite/java/src/test | |
parent | 23bdaed4fbcd3b335a4699f6ed02176a0b6a91c9 (diff) |
Expose Quantization params for outputs in JNI interpreter
PiperOrigin-RevId: 200795402
Diffstat (limited to 'tensorflow/contrib/lite/java/src/test')
-rw-r--r-- | tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 7c00d3196f..9e41cb132d 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -41,6 +41,9 @@ public final class NativeInterpreterWrapperTest { private static final String BYTE_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; @@ -536,4 +539,16 @@ public final class NativeInterpreterWrapperTest { assertThat(wrapper.getOutputDataType(0)).contains("byte"); wrapper.close(); } + + @Test + public void testGetOutputQuantizationParams() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); + } + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); + } + } } |