aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session_test.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-10-02 11:18:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 11:22:14 -0700
commit45bcc10973f3bbff1f189f8927e568c2f91b3b52 (patch)
treeef1167eca17ee3c83bf1ae3ca0c150c230082c14 /tensorflow/python/training/monitored_session_test.py
parent9bfa43625061ec62bd9623ab014db4851307e92d (diff)
Automated g4 rollback of changelist 170525148
PiperOrigin-RevId: 170726693
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--tensorflow/python/training/monitored_session_test.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 84d262935a..d88b187fde 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -1024,6 +1024,7 @@ class MonitoredSessionTest(test.TestCase):
do_step = state_ops.assign_add(gstep, 1)
# Run till step 3 and save.
hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
+ scaffold = monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession(hooks=hooks) as session:
self.assertEqual(0, session.run(gstep))
self.assertFalse(session.should_stop())
@@ -1033,9 +1034,8 @@ class MonitoredSessionTest(test.TestCase):
self.assertFalse(session.should_stop())
self.assertEqual(3, session.run(do_step))
self.assertTrue(session.should_stop())
- save_path = saver_lib._get_saver_or_default().save(
- session._coordinated_creator.tf_sess,
- os.path.join(logdir, 'step-3'))
+ save_path = scaffold.saver.save(session._coordinated_creator.tf_sess,
+ os.path.join(logdir, 'step-3'))
# Run till step 5 and save.
def load_ckpt(scaffold, sess):
scaffold.saver.restore(sess, save_path)
@@ -1059,6 +1059,7 @@ class MonitoredSessionTest(test.TestCase):
do_step = state_ops.assign_add(gstep, 1)
# Do 3 steps and save.
hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
+ scaffold = monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession(hooks=hooks) as session:
session.run(do_step)
self.assertFalse(session.should_stop())
@@ -1066,9 +1067,8 @@ class MonitoredSessionTest(test.TestCase):
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertTrue(session.should_stop())
- save_path = saver_lib._get_saver_or_default().save(
- session._coordinated_creator.tf_sess,
- os.path.join(logdir, 'step-3'))
+ save_path = scaffold.saver.save(session._coordinated_creator.tf_sess,
+ os.path.join(logdir, 'step-3'))
# Restore and do 4 steps.
def load_ckpt(scaffold, sess):
scaffold.saver.restore(sess, save_path)