aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py')
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
index 09eecb56dc..51ca5ec125 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
@@ -168,6 +168,29 @@ class CheckpointsTest(test.TestCase):
self.assertAllEqual(my3.eval(session), v3)
self.assertAllEqual(my4.eval(session), v4)
+ def testInitToRootCheckpoint(self):
+ checkpoint_dir = self.get_temp_dir()
+ with self.test_session() as session:
+ v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
+
+ # New graph and session.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as session:
+ my1 = variable_scope.get_variable("var1", [1, 10])
+ my2 = variable_scope.get_variable("var2", [10, 10])
+ my3 = variable_scope.get_variable("var3", [100, 100])
+ with variable_scope.variable_scope("useful_scope"):
+ my4 = variable_scope.get_variable("var4", [9, 9])
+
+ checkpoint_utils.init_from_checkpoint(checkpoint_dir,
+ {"/": "/",})
+
+ session.run(variables.global_variables_initializer())
+ self.assertAllEqual(my1.eval(session), v1)
+ self.assertAllEqual(my2.eval(session), v2)
+ self.assertAllEqual(my3.eval(session), v3)
+ self.assertAllEqual(my4.eval(session), v4)
+
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
with self.test_session() as session: