diff options
-rw-r--r-- | tensorflow/python/training/monitored_session.py | 19 | ||||
-rw-r--r-- | tensorflow/python/training/monitored_session_test.py | 62 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt | 2 |
3 files changed, 81 insertions, 2 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index a891bae5f2..7f737399ab 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -102,7 +102,8 @@ class Scaffold(object): ready_for_local_init_op=None, local_init_op=None, summary_op=None, - saver=None): + saver=None, + copy_from_scaffold=None): """Create a scaffold. Args: @@ -125,10 +126,26 @@ class Scaffold(object): string tensor containing a serialized `Summary` proto. saver: Optional `tf.train.Saver` object to use to save and restore variables. + copy_from_scaffold: Optional scaffold object to copy fields from. Its + fields will be overwritten by the provided fields in this function. """ + if copy_from_scaffold: + if not isinstance(copy_from_scaffold, Scaffold): + raise TypeError('copy_from_scaffold is not a Scaffold instance.') + init_op = init_op or copy_from_scaffold.init_op + init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict + # Use the original init_fn provided by the user to init the new Scaffold. + init_fn = init_fn or copy_from_scaffold._user_init_fn # pylint: disable=protected-access + ready_op = ready_op or copy_from_scaffold.ready_op + ready_for_local_init_op = ready_for_local_init_op or ( + copy_from_scaffold.ready_for_local_init_op) + local_init_op = local_init_op or copy_from_scaffold.local_init_op + summary_op = summary_op or copy_from_scaffold.summary_op + saver = saver or copy_from_scaffold.saver # NOTE(touts): modifying the init function to be passed the scaffold is a # hack to make it easy to find the saver. Is there a better way? + self._user_init_fn = init_fn if init_fn: self._init_fn = lambda sess: init_fn(self, sess) else: diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 41f8fb3486..85a5ceeb08 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -147,6 +147,68 @@ class ScaffoldTest(test.TestCase): 'Graph is finalized and cannot be modified'): constant_op.constant([0]) + def test_new_scaffold_from_default_scaffold(self): + scaffold1 = monitored_session.Scaffold() + with ops.Graph().as_default(): + variables.Variable([1]) + saver = saver_lib.Saver() + scaffold2 = monitored_session.Scaffold( + init_op=2, + init_feed_dict=3, + init_fn=lambda scaffold, sess: 4, + ready_op=5, + ready_for_local_init_op=6, + local_init_op=7, + saver=saver, + copy_from_scaffold=scaffold1) + + scaffold2.finalize() + self.assertEqual(2, scaffold2.init_op) + self.assertEqual(3, scaffold2.init_feed_dict) + self.assertTrue(callable(scaffold2.init_fn)) + self.assertEqual(5, scaffold2.ready_op) + self.assertEqual(6, scaffold2.ready_for_local_init_op) + self.assertEqual(7, scaffold2.local_init_op) + self.assertEqual(saver, scaffold2.saver) + + def test_new_scaffold_from_existing_scaffold(self): + with ops.Graph().as_default(): + variables.Variable([1]) + saver = saver_lib.Saver() + scaffold1 = monitored_session.Scaffold( + init_op=2, + init_feed_dict=3, + init_fn=lambda scaffold, sess: 4, + ready_op=5, + ready_for_local_init_op=6, + local_init_op=7, + saver=saver) + + scaffold2 = monitored_session.Scaffold( + init_op=4, + init_feed_dict=6, + init_fn=lambda scaffold, sess: 8, + ready_op=10, + ready_for_local_init_op=12, + local_init_op=14, + saver=saver, + copy_from_scaffold=scaffold1) + + scaffold2.finalize() + self.assertEqual(4, scaffold2.init_op) + self.assertEqual(6, scaffold2.init_feed_dict) + self.assertTrue(callable(scaffold2.init_fn)) + self.assertEqual(10, scaffold2.ready_op) + self.assertEqual(12, scaffold2.ready_for_local_init_op) + self.assertEqual(14, scaffold2.local_init_op) + self.assertEqual(saver, scaffold2.saver) + + def test_copy_from_scaffold_is_scaffold(self): + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + TypeError, 'copy_from_scaffold is not a Scaffold instance'): + monitored_session.Scaffold(copy_from_scaffold=1) + def _test_dir(temp_dir, test_name): """Create an empty dir to use for tests. diff --git a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt index 21234fe739..62b956c5ef 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt @@ -36,7 +36,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "finalize" |