aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-05-31 13:58:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 14:01:07 -0700
commitd3b5b07e7810782c3760468312f9cace10b89073 (patch)
tree44a32b8659de0ebb98c3756d3dc6040198c6dd25 /tensorflow/contrib/lite/python/lite_test.py
parent89a55fef3316e0e270e0f87f71bd8c2d32443cc8 (diff)
Add attributes to TFLite Python API.
PiperOrigin-RevId: 198774775
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 28386ecb1a..1b0cdb90ce 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -220,6 +220,67 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ def testInferenceInputType(self):
+ in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ def testDefaultRangesStats(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
+ converter.default_ranges_stats = (0, 6) # min, max
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+
class FromFlatbufferFile(test_util.TensorFlowTestCase):