diff options
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r-- | tensorflow/python/training/monitored_session_test.py | 58 |
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index ff586b6c03..2d7799d66a 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -80,7 +80,7 @@ class ScaffoldTest(test.TestCase): self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertItemsEqual([b'my_var', b'my_local_var'], sess.run(scaffold.ready_op)) self.assertItemsEqual([b'my_var'], @@ -513,21 +513,21 @@ class WrappedSessionTest(test.TestCase): """_WrappedSession tests.""" def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) wrapped_sess = monitored_session._WrappedSession(sess) self.assertEquals(sess.graph, wrapped_sess.graph) self.assertEquals(sess.sess_str, wrapped_sess.sess_str) def test_should_stop_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = monitored_session._WrappedSession(sess) self.assertFalse(wrapped_sess.should_stop()) wrapped_sess.close() self.assertTrue(wrapped_sess.should_stop()) def test_should_stop_uses_check_stop(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = StopAtNSession(sess, 3) self.assertFalse(wrapped_sess.should_stop()) self.assertFalse(wrapped_sess.should_stop()) @@ -535,7 +535,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess.should_stop()) def test_should_stop_delegates_to_wrapped_session(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess0 = StopAtNSession(sess, 4) wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0) self.assertFalse(wrapped_sess1.should_stop()) @@ -545,7 +545,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess1.should_stop()) def test_close_twice(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = monitored_session._WrappedSession(sess) wrapped_sess.close() self.assertTrue(wrapped_sess.should_stop()) @@ -553,7 +553,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess.should_stop()) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) self.assertEqual(42, sess.run(v, feed_dict={c: 42})) @@ -570,7 +570,7 @@ class CoordinatedSessionTest(test.TestCase): """_CoordinatedSession tests.""" def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) @@ -578,7 +578,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertEquals(sess.sess_str, coord_sess.sess_str) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -586,7 +586,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42})) def test_should_stop_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) self.assertFalse(coord_sess.should_stop()) @@ -594,7 +594,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def test_should_stop_on_coord_stop(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) self.assertFalse(coord_sess.should_stop()) @@ -602,7 +602,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def test_dont_request_stop_on_exception_in_main_thread(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -616,7 +616,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertFalse(coord_sess.should_stop()) def test_stop_threads_on_close_after_exception(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -646,7 +646,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def test_stop_threads_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = [ threading.Thread( @@ -664,7 +664,7 @@ class CoordinatedSessionTest(test.TestCase): def test_propagates_exception_trace(self): assertion = control_flow_ops.Assert(False, ['This should fail.']) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator(clean_stop_exception_types=()) coord_sess = monitored_session._CoordinatedSession(sess, coord) try: @@ -810,7 +810,7 @@ class RecoverableSessionTest(test.TestCase): return self._sess def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) recoverable_sess = monitored_session._RecoverableSession( self._SessionReturner(sess)) @@ -818,7 +818,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEquals(sess.sess_str, recoverable_sess.sess_str) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) recoverable_sess = monitored_session._RecoverableSession( @@ -826,7 +826,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51})) def test_recovery(self): - with self.test_session() as sess: + with self.cached_session() as sess: class StackSessionCreator(object): @@ -872,7 +872,7 @@ class RecoverableSessionTest(test.TestCase): recoverable_sess.run(v, feed_dict={c: -12}) def test_recovery_from_coordinator_exception(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -897,7 +897,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) hook = StopCoordinatorWithException( calls_before_stopping=2, @@ -926,7 +926,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -950,7 +950,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -980,7 +980,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) hook = StopCoordinatorWithException( calls_before_stopping=2, @@ -1014,7 +1014,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -1058,7 +1058,7 @@ class RecoverableSessionTest(test.TestCase): return session def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -1090,7 +1090,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -1127,7 +1127,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -2047,7 +2047,7 @@ class MonitoredSessionTest(test.TestCase): return value - with self.test_session() as test_session: + with self.cached_session() as test_session: with monitored_session.MonitoredSession( CountingSessionCreator(test_session)) as session: session.run(variables.global_variables_initializer()) @@ -2110,7 +2110,7 @@ class MonitoredSessionTest(test.TestCase): step_context.session.run(graph_side_effect) return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3}) - with self.test_session() as test_session: + with self.cached_session() as test_session: with monitored_session.MonitoredSession( CountingSessionCreator(test_session), hooks=[Hook(self)]) as session: |