diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-06-07 12:20:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-07 12:23:25 -0700 |
commit | 5c74172fa5bd9f2ae6275d536f70971810a40548 (patch) | |
tree | c20f76a41e543ef228bb6606603731054da55616 /tensorflow/contrib/lite/python/lite_test.py | |
parent | 0dab0f538b78b0a0f1ec4f7dc5fb3005b5efdc94 (diff) |
Add features to TOCO Python API.
PiperOrigin-RevId: 199676295
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index b04caaf263..8c9d2c1651 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -220,6 +220,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # TODO(nupurgarg): Verify value of contents in GraphViz. def testGraphviz(self): in_tensor = array_ops.placeholder( shape=[1, 16, 16, 3], dtype=dtypes.float32) @@ -232,6 +233,39 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + def testInferenceInputType(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) out_tensor = in_tensor + in_tensor |