aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-01-31 16:14:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 16:25:22 -0800
commit921e4d2ac45b217ebe055546138cb139fc71ddca (patch)
treefd0e4935ddd13def2eb12daf88347311a2bd89e0
parentd33692e9ccb8974946d4b0142e7b1046ffd05cb1 (diff)
Add a Saver test to ensure it continues working on URIs.
Change: 146179848
-rw-r--r--tensorflow/python/training/saver_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index d0c73d826a..3072954fbd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -565,6 +565,26 @@ class SaverTest(test.TestCase):
):
save.save(sess, save_path)
+ def testSaveToURI(self):
+ save_path = "file://" + os.path.join(
+ self.get_temp_dir(), "uri")
+
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(20.0, name="v1")
+ save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
+ init_all_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ # Initialize all variables
+ sess.run(init_all_op)
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+ save.save(sess, save_path)
+
class SaveRestoreShardedTest(test.TestCase):