aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 13:55:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 13:58:54 -0700
commitcb926e1ed73d6d8f7158cdabf5c4265a921a407b (patch)
tree307b80f49256003777f1e2c874228fea92d55c25 /tensorflow/python/training
parent081d9b7fa17fb9f4ea39b5ac5cc20432ae5d1756 (diff)
Fixes a bug in tf.train.Saver() where it couldn't use Checkpointable
objects in a tf.train.Saver() if var_list was a dict. Includes the logic used for list in the dict code path. PiperOrigin-RevId: 214324913
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/saver.py8
-rw-r--r--tensorflow/python/training/saver_test.py32
2 files changed, 25 insertions, 15 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 274c856686..5b2b19e913 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -622,6 +622,14 @@ class BaseSaverBuilder(object):
yield BaseSaverBuilder.ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
+ elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
+ op, variables.Variable):
+ # pylint: disable=protected-access
+ for attr, factory in op._gather_saveables_for_checkpoint().items():
+ op = (factory(name + "_" + attr) if callable(factory) else factory)
+ for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name):
+ yield op
+ # pylint: enable=protected-access
else:
# A variable or tensor.
if context.executing_eagerly():
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 0ac84813c8..69b1055ebe 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2850,30 +2850,32 @@ class CheckpointableCompatibilityTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsCheckpointable(self):
v = _OwnsAVariableSimple()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.cached_session() as sess:
- self.evaluate(v.non_dep_variable.assign(42.))
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
@test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
- with self.cached_session() as sess:
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- self.evaluate(v.mirrored.assign(44.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
- self.assertEqual(42., self.evaluate(v.mirrored))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ self.evaluate(v.mirrored.assign(44.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ self.assertEqual(42., self.evaluate(v.mirrored))
def testSingleTensorEvaluation(self):