aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--tensorflow/python/training/monitored_session_test.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index ff586b6c03..2d7799d66a 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -80,7 +80,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertItemsEqual([b'my_var', b'my_local_var'],
sess.run(scaffold.ready_op))
self.assertItemsEqual([b'my_var'],
@@ -513,21 +513,21 @@ class WrappedSessionTest(test.TestCase):
"""_WrappedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertEquals(sess.graph, wrapped_sess.graph)
self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertFalse(wrapped_sess.should_stop())
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_uses_check_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = StopAtNSession(sess, 3)
self.assertFalse(wrapped_sess.should_stop())
self.assertFalse(wrapped_sess.should_stop())
@@ -535,7 +535,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_delegates_to_wrapped_session(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess0 = StopAtNSession(sess, 4)
wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0)
self.assertFalse(wrapped_sess1.should_stop())
@@ -545,7 +545,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess1.should_stop())
def test_close_twice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
@@ -553,7 +553,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
@@ -570,7 +570,7 @@ class CoordinatedSessionTest(test.TestCase):
"""_CoordinatedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
@@ -578,7 +578,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, coord_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -586,7 +586,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -594,7 +594,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_should_stop_on_coord_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -602,7 +602,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_dont_request_stop_on_exception_in_main_thread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -616,7 +616,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertFalse(coord_sess.should_stop())
def test_stop_threads_on_close_after_exception(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -646,7 +646,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_stop_threads_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = [
threading.Thread(
@@ -664,7 +664,7 @@ class CoordinatedSessionTest(test.TestCase):
def test_propagates_exception_trace(self):
assertion = control_flow_ops.Assert(False, ['This should fail.'])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator(clean_stop_exception_types=())
coord_sess = monitored_session._CoordinatedSession(sess, coord)
try:
@@ -810,7 +810,7 @@ class RecoverableSessionTest(test.TestCase):
return self._sess
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
recoverable_sess = monitored_session._RecoverableSession(
self._SessionReturner(sess))
@@ -818,7 +818,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
recoverable_sess = monitored_session._RecoverableSession(
@@ -826,7 +826,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_recovery(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
class StackSessionCreator(object):
@@ -872,7 +872,7 @@ class RecoverableSessionTest(test.TestCase):
recoverable_sess.run(v, feed_dict={c: -12})
def test_recovery_from_coordinator_exception(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -897,7 +897,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -926,7 +926,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -950,7 +950,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -980,7 +980,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -1014,7 +1014,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -1058,7 +1058,7 @@ class RecoverableSessionTest(test.TestCase):
return session
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1090,7 +1090,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1127,7 +1127,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -2047,7 +2047,7 @@ class MonitoredSessionTest(test.TestCase):
return value
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session)) as session:
session.run(variables.global_variables_initializer())
@@ -2110,7 +2110,7 @@ class MonitoredSessionTest(test.TestCase):
step_context.session.run(graph_side_effect)
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session),
hooks=[Hook(self)]) as session: