diff options
author | 2018-03-21 19:12:18 -0700 | |
---|---|---|
committer | 2018-03-21 19:14:52 -0700 | |
commit | 73bd57d80111dc957d117b6ae98bc2354f766604 (patch) | |
tree | 101b4f5ec9ac869041eb4f2455384085c12d3af3 /tensorflow/contrib/lite/python/interpreter_test.py | |
parent | 61aa925ebaa69b9526cc67384fcde3fa42c9e6f1 (diff) |
Add tensor quantization info to python wrapper
PiperOrigin-RevId: 190005998
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_test.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index e85390c56c..bf124410f3 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -39,12 +39,14 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertEqual('input', input_details[0]['name']) self.assertEqual(np.float32, input_details[0]['dtype']) self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + self.assertEqual((0.0, 0), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('output', output_details[0]['name']) self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0.0, 0), output_details[0]['quantization']) test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32) @@ -67,12 +69,14 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertEqual('input', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + self.assertEqual((1.0, 0), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('output', output_details[0]['name']) self.assertEqual(np.uint8, output_details[0]['dtype']) self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((1.0, 0), output_details[0]['quantization']) test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) |