diff options
author | 2017-05-15 13:56:20 -0700 | |
---|---|---|
committer | 2017-05-15 14:00:10 -0700 | |
commit | d319541a809b6f19e04c65a7a186d684b4000ed9 (patch) | |
tree | fcb03f8d94f98257435752ad3a3524186865e512 /tensorflow/python/training/monitored_session.py | |
parent | 5b47c8b6e38fa4b74304d8265c22d53ee385c03b (diff) |
Add copy_from_scaffold parameter in the Scaffold constructor. This allows creating a new Scaffold instance from an exiting one by copying the fields of the original scaffold and replacing those that are provided in the constructor.
PiperOrigin-RevId: 156099660
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r-- | tensorflow/python/training/monitored_session.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index a891bae5f2..7f737399ab 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -102,7 +102,8 @@ class Scaffold(object): ready_for_local_init_op=None, local_init_op=None, summary_op=None, - saver=None): + saver=None, + copy_from_scaffold=None): """Create a scaffold. Args: @@ -125,10 +126,26 @@ class Scaffold(object): string tensor containing a serialized `Summary` proto. saver: Optional `tf.train.Saver` object to use to save and restore variables. + copy_from_scaffold: Optional scaffold object to copy fields from. Its + fields will be overwritten by the provided fields in this function. """ + if copy_from_scaffold: + if not isinstance(copy_from_scaffold, Scaffold): + raise TypeError('copy_from_scaffold is not a Scaffold instance.') + init_op = init_op or copy_from_scaffold.init_op + init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict + # Use the original init_fn provided by the user to init the new Scaffold. + init_fn = init_fn or copy_from_scaffold._user_init_fn # pylint: disable=protected-access + ready_op = ready_op or copy_from_scaffold.ready_op + ready_for_local_init_op = ready_for_local_init_op or ( + copy_from_scaffold.ready_for_local_init_op) + local_init_op = local_init_op or copy_from_scaffold.local_init_op + summary_op = summary_op or copy_from_scaffold.summary_op + saver = saver or copy_from_scaffold.saver # NOTE(touts): modifying the init function to be passed the scaffold is a # hack to make it easy to find the saver. Is there a better way? + self._user_init_fn = init_fn if init_fn: self._init_fn = lambda sess: init_fn(self, sess) else: |