aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/interpreter_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-21 19:12:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 19:14:52 -0700
commit73bd57d80111dc957d117b6ae98bc2354f766604 (patch)
tree101b4f5ec9ac869041eb4f2455384085c12d3af3 /tensorflow/contrib/lite/python/interpreter_test.py
parent61aa925ebaa69b9526cc67384fcde3fa42c9e6f1 (diff)
Add tensor quantization info to python wrapper
PiperOrigin-RevId: 190005998
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_test.py')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index e85390c56c..bf124410f3 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -39,12 +39,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), output_details[0]['quantization'])
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
@@ -67,12 +69,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.uint8, output_details[0]['dtype'])
self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), output_details[0]['quantization'])
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)