diff options
author | 2018-09-24 13:55:07 -0700 | |
---|---|---|
committer | 2018-09-24 13:58:54 -0700 | |
commit | cb926e1ed73d6d8f7158cdabf5c4265a921a407b (patch) | |
tree | 307b80f49256003777f1e2c874228fea92d55c25 /tensorflow/python/training | |
parent | 081d9b7fa17fb9f4ea39b5ac5cc20432ae5d1756 (diff) |
Fixes a bug in tf.train.Saver() where it couldn't use Checkpointable
objects in a tf.train.Saver() if var_list was a dict.
Includes the logic used for list in the dict code path.
PiperOrigin-RevId: 214324913
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/saver.py | 8 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 32 |
2 files changed, 25 insertions, 15 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 274c856686..5b2b19e913 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -622,6 +622,14 @@ class BaseSaverBuilder(object): yield BaseSaverBuilder.ResourceVariableSaveable( variable, variable._save_slice_info.spec, name) # pylint: enable=protected-access + elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance( + op, variables.Variable): + # pylint: disable=protected-access + for attr, factory in op._gather_saveables_for_checkpoint().items(): + op = (factory(name + "_" + attr) if callable(factory) else factory) + for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name): + yield op + # pylint: enable=protected-access else: # A variable or tensor. if context.executing_eagerly(): diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 0ac84813c8..69b1055ebe 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -2850,30 +2850,32 @@ class CheckpointableCompatibilityTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNotSaveableButIsCheckpointable(self): v = _OwnsAVariableSimple() - saver = saver_module.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - with self.cached_session() as sess: - self.evaluate(v.non_dep_variable.assign(42.)) - save_path = saver.save(sess, prefix) - self.evaluate(v.non_dep_variable.assign(43.)) - saver.restore(sess, save_path) - self.assertEqual(42., self.evaluate(v.non_dep_variable)) + for saver in (saver_module.Saver(var_list=[v]), + saver_module.Saver(var_list={"v": v})): + with self.cached_session() as sess: + self.evaluate(v.non_dep_variable.assign(42.)) + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() - saver = saver_module.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") self.evaluate(v.non_dep_variable.assign(42.)) - with self.cached_session() as sess: - save_path = saver.save(sess, prefix) - self.evaluate(v.non_dep_variable.assign(43.)) - self.evaluate(v.mirrored.assign(44.)) - saver.restore(sess, save_path) - self.assertEqual(42., self.evaluate(v.non_dep_variable)) - self.assertEqual(42., self.evaluate(v.mirrored)) + for saver in (saver_module.Saver(var_list=[v]), + saver_module.Saver(var_list={"v": v})): + with self.cached_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) def testSingleTensorEvaluation(self): |