aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-25 16:20:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-25 16:22:52 -0700
commiteb31cf8a62739d4df4c84b8edeccbe756b70616d (patch)
treebe9cc72a53afc15637647d10e9182c4fb7fc2005 /tensorflow/contrib/optimizer_v2
parent1ab4ea34fca26974afbe078b7b9f8d44a9a58858 (diff)
Checkpointable: better handling of objects which aren't being restored
initialize_or_restore on a tf.train.Checkpoint status object will now initialize any variables which aren't being restored, which is closer to the behavior when executing eagerly (and makes it easier to use). Fixes a bug where assert_consumed() would miss some Python objects which aren't part of the object graph being restored. It will now (correctly/as documented) complain about unmatched Python objects in the dependency graph. PiperOrigin-RevId: 194315742
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 8ac9b58145..9e2858d00f 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -702,8 +702,7 @@ class CheckpointCompatibilityTests(test.TestCase):
with save_graph.as_default(), self.test_session(
graph=save_graph) as session:
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(
+ save_path = root.save(
session=session, file_prefix=checkpoint_prefix)
with context.eager_mode():
root = self._initialized_model()
@@ -716,8 +715,7 @@ class CheckpointCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with context.eager_mode():
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(file_prefix=checkpoint_prefix)
+ save_path = root.save(file_prefix=checkpoint_prefix)
with context.graph_mode():
save_graph = ops.Graph()
with save_graph.as_default(), self.test_session(