aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/saver_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r--tensorflow/python/training/saver_test.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 03d1c06476..af9f13f438 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1837,14 +1837,22 @@ class WriteGraphTest(test.TestCase):
def testWriteGraph(self):
test_dir = _TestDir("write_graph_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- graph_io.write_graph(ops_lib.get_default_graph(),
- "/".join([test_dir, "l1"]), "graph.pbtxt")
+ path = graph_io.write_graph(ops_lib.get_default_graph(),
+ os.path.join(test_dir, "l1"), "graph.pbtxt")
+ truth = os.path.join(test_dir, "l1", "graph.pbtxt")
+ self.assertEqual(path, truth)
+ self.assertTrue(os.path.exists(path))
+
def testRecursiveCreate(self):
test_dir = _TestDir("deep_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
- "/".join([test_dir, "l1/l2/l3"]), "graph.pbtxt")
+ path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
+ os.path.join(test_dir, "l1", "l2", "l3"),
+ "graph.pbtxt")
+ truth = os.path.join(test_dir, 'l1', 'l2', 'l3', "graph.pbtxt")
+ self.assertEqual(path, truth)
+ self.assertTrue(os.path.exists(path))
class SaverUtilsTest(test.TestCase):