diff options
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r-- | tensorflow/python/training/saver_test.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 03d1c06476..af9f13f438 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -1837,14 +1837,22 @@ class WriteGraphTest(test.TestCase): def testWriteGraph(self): test_dir = _TestDir("write_graph_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") - graph_io.write_graph(ops_lib.get_default_graph(), - "/".join([test_dir, "l1"]), "graph.pbtxt") + path = graph_io.write_graph(ops_lib.get_default_graph(), + os.path.join(test_dir, "l1"), "graph.pbtxt") + truth = os.path.join(test_dir, "l1", "graph.pbtxt") + self.assertEqual(path, truth) + self.assertTrue(os.path.exists(path)) + def testRecursiveCreate(self): test_dir = _TestDir("deep_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") - graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), - "/".join([test_dir, "l1/l2/l3"]), "graph.pbtxt") + path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), + os.path.join(test_dir, "l1", "l2", "l3"), + "graph.pbtxt") + truth = os.path.join(test_dir, 'l1', 'l2', 'l3', "graph.pbtxt") + self.assertEqual(path, truth) + self.assertTrue(os.path.exists(path)) class SaverUtilsTest(test.TestCase): |