aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-06-07 12:20:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 12:23:25 -0700
commit5c74172fa5bd9f2ae6275d536f70971810a40548 (patch)
treec20f76a41e543ef228bb6606603731054da55616 /tensorflow/contrib/lite/python/lite_test.py
parent0dab0f538b78b0a0f1ec4f7dc5fb3005b5efdc94 (diff)
Add features to TOCO Python API.
PiperOrigin-RevId: 199676295
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index b04caaf263..8c9d2c1651 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -220,6 +220,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
def testGraphviz(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
@@ -232,6 +233,39 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
+ def testDumpGraphviz(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure interpreter is able to allocate and check graphviz data.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ num_items_graphviz = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ converter.dump_graphviz_video = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure graphviz folder has more data after using video flag.
+ num_items_graphviz_video = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz_video > num_items_graphviz)
+
def testInferenceInputType(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
out_tensor = in_tensor + in_tensor