aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-15 13:56:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-15 14:00:10 -0700
commitd319541a809b6f19e04c65a7a186d684b4000ed9 (patch)
treefcb03f8d94f98257435752ad3a3524186865e512 /tensorflow/python/training/monitored_session.py
parent5b47c8b6e38fa4b74304d8265c22d53ee385c03b (diff)
Add copy_from_scaffold parameter in the Scaffold constructor. This allows creating a new Scaffold instance from an exiting one by copying the fields of the original scaffold and replacing those that are provided in the constructor.
PiperOrigin-RevId: 156099660
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r--tensorflow/python/training/monitored_session.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index a891bae5f2..7f737399ab 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -102,7 +102,8 @@ class Scaffold(object):
ready_for_local_init_op=None,
local_init_op=None,
summary_op=None,
- saver=None):
+ saver=None,
+ copy_from_scaffold=None):
"""Create a scaffold.
Args:
@@ -125,10 +126,26 @@ class Scaffold(object):
string tensor containing a serialized `Summary` proto.
saver: Optional `tf.train.Saver` object to use to save and restore
variables.
+ copy_from_scaffold: Optional scaffold object to copy fields from. Its
+ fields will be overwritten by the provided fields in this function.
"""
+ if copy_from_scaffold:
+ if not isinstance(copy_from_scaffold, Scaffold):
+ raise TypeError('copy_from_scaffold is not a Scaffold instance.')
+ init_op = init_op or copy_from_scaffold.init_op
+ init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict
+ # Use the original init_fn provided by the user to init the new Scaffold.
+ init_fn = init_fn or copy_from_scaffold._user_init_fn # pylint: disable=protected-access
+ ready_op = ready_op or copy_from_scaffold.ready_op
+ ready_for_local_init_op = ready_for_local_init_op or (
+ copy_from_scaffold.ready_for_local_init_op)
+ local_init_op = local_init_op or copy_from_scaffold.local_init_op
+ summary_op = summary_op or copy_from_scaffold.summary_op
+ saver = saver or copy_from_scaffold.saver
# NOTE(touts): modifying the init function to be passed the scaffold is a
# hack to make it easy to find the saver. Is there a better way?
+ self._user_init_fn = init_fn
if init_fn:
self._init_fn = lambda sess: init_fn(self, sess)
else: