diff options
author | Allen Lavoie <allenl@google.com> | 2018-04-25 16:20:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-25 16:22:52 -0700 |
commit | eb31cf8a62739d4df4c84b8edeccbe756b70616d (patch) | |
tree | be9cc72a53afc15637647d10e9182c4fb7fc2005 /tensorflow/contrib/optimizer_v2 | |
parent | 1ab4ea34fca26974afbe078b7b9f8d44a9a58858 (diff) |
Checkpointable: better handling of objects which aren't being restored
initialize_or_restore on a tf.train.Checkpoint status object will now initialize
any variables which aren't being restored, which is closer to the behavior when
executing eagerly (and makes it easier to use).
Fixes a bug where assert_consumed() would miss some Python objects which aren't
part of the object graph being restored. It will now (correctly/as documented)
complain about unmatched Python objects in the dependency graph.
PiperOrigin-RevId: 194315742
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 8ac9b58145..9e2858d00f 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -702,8 +702,7 @@ class CheckpointCompatibilityTests(test.TestCase): with save_graph.as_default(), self.test_session( graph=save_graph) as session: root = self._initialized_model() - object_saver = checkpointable_utils.CheckpointableSaver(root) - save_path = object_saver.save( + save_path = root.save( session=session, file_prefix=checkpoint_prefix) with context.eager_mode(): root = self._initialized_model() @@ -716,8 +715,7 @@ class CheckpointCompatibilityTests(test.TestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") with context.eager_mode(): root = self._initialized_model() - object_saver = checkpointable_utils.CheckpointableSaver(root) - save_path = object_saver.save(file_prefix=checkpoint_prefix) + save_path = root.save(file_prefix=checkpoint_prefix) with context.graph_mode(): save_graph = ops.Graph() with save_graph.as_default(), self.test_session( |