aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/training/monitored_session.py19
-rw-r--r--tensorflow/python/training/monitored_session_test.py62
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt2
3 files changed, 81 insertions, 2 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index a891bae5f2..7f737399ab 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -102,7 +102,8 @@ class Scaffold(object):
ready_for_local_init_op=None,
local_init_op=None,
summary_op=None,
- saver=None):
+ saver=None,
+ copy_from_scaffold=None):
"""Create a scaffold.
Args:
@@ -125,10 +126,26 @@ class Scaffold(object):
string tensor containing a serialized `Summary` proto.
saver: Optional `tf.train.Saver` object to use to save and restore
variables.
+ copy_from_scaffold: Optional scaffold object to copy fields from. Its
+ fields will be overwritten by the provided fields in this function.
"""
+ if copy_from_scaffold:
+ if not isinstance(copy_from_scaffold, Scaffold):
+ raise TypeError('copy_from_scaffold is not a Scaffold instance.')
+ init_op = init_op or copy_from_scaffold.init_op
+ init_feed_dict = init_feed_dict or copy_from_scaffold.init_feed_dict
+ # Use the original init_fn provided by the user to init the new Scaffold.
+ init_fn = init_fn or copy_from_scaffold._user_init_fn # pylint: disable=protected-access
+ ready_op = ready_op or copy_from_scaffold.ready_op
+ ready_for_local_init_op = ready_for_local_init_op or (
+ copy_from_scaffold.ready_for_local_init_op)
+ local_init_op = local_init_op or copy_from_scaffold.local_init_op
+ summary_op = summary_op or copy_from_scaffold.summary_op
+ saver = saver or copy_from_scaffold.saver
# NOTE(touts): modifying the init function to be passed the scaffold is a
# hack to make it easy to find the saver. Is there a better way?
+ self._user_init_fn = init_fn
if init_fn:
self._init_fn = lambda sess: init_fn(self, sess)
else:
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 41f8fb3486..85a5ceeb08 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -147,6 +147,68 @@ class ScaffoldTest(test.TestCase):
'Graph is finalized and cannot be modified'):
constant_op.constant([0])
+ def test_new_scaffold_from_default_scaffold(self):
+ scaffold1 = monitored_session.Scaffold()
+ with ops.Graph().as_default():
+ variables.Variable([1])
+ saver = saver_lib.Saver()
+ scaffold2 = monitored_session.Scaffold(
+ init_op=2,
+ init_feed_dict=3,
+ init_fn=lambda scaffold, sess: 4,
+ ready_op=5,
+ ready_for_local_init_op=6,
+ local_init_op=7,
+ saver=saver,
+ copy_from_scaffold=scaffold1)
+
+ scaffold2.finalize()
+ self.assertEqual(2, scaffold2.init_op)
+ self.assertEqual(3, scaffold2.init_feed_dict)
+ self.assertTrue(callable(scaffold2.init_fn))
+ self.assertEqual(5, scaffold2.ready_op)
+ self.assertEqual(6, scaffold2.ready_for_local_init_op)
+ self.assertEqual(7, scaffold2.local_init_op)
+ self.assertEqual(saver, scaffold2.saver)
+
+ def test_new_scaffold_from_existing_scaffold(self):
+ with ops.Graph().as_default():
+ variables.Variable([1])
+ saver = saver_lib.Saver()
+ scaffold1 = monitored_session.Scaffold(
+ init_op=2,
+ init_feed_dict=3,
+ init_fn=lambda scaffold, sess: 4,
+ ready_op=5,
+ ready_for_local_init_op=6,
+ local_init_op=7,
+ saver=saver)
+
+ scaffold2 = monitored_session.Scaffold(
+ init_op=4,
+ init_feed_dict=6,
+ init_fn=lambda scaffold, sess: 8,
+ ready_op=10,
+ ready_for_local_init_op=12,
+ local_init_op=14,
+ saver=saver,
+ copy_from_scaffold=scaffold1)
+
+ scaffold2.finalize()
+ self.assertEqual(4, scaffold2.init_op)
+ self.assertEqual(6, scaffold2.init_feed_dict)
+ self.assertTrue(callable(scaffold2.init_fn))
+ self.assertEqual(10, scaffold2.ready_op)
+ self.assertEqual(12, scaffold2.ready_for_local_init_op)
+ self.assertEqual(14, scaffold2.local_init_op)
+ self.assertEqual(saver, scaffold2.saver)
+
+ def test_copy_from_scaffold_is_scaffold(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ TypeError, 'copy_from_scaffold is not a Scaffold instance'):
+ monitored_session.Scaffold(copy_from_scaffold=1)
+
def _test_dir(temp_dir, test_name):
"""Create an empty dir to use for tests.
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt
index 21234fe739..62b956c5ef 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt
@@ -36,7 +36,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "finalize"