diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/session_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/session_ops_test.py | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py index 678016b13d..03e1ae852f 100644 --- a/tensorflow/python/kernel_tests/session_ops_test.py +++ b/tensorflow/python/kernel_tests/session_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class SessionOpsTest(test.TestCase): def testHandleBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle. a = constant_op.constant(10) b = constant_op.constant(5) @@ -45,7 +45,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(500, sess.run(y, feed_dict={f: h.handle})) def testHandleEval(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle. a = constant_op.constant(10) b = constant_op.constant(5) @@ -57,7 +57,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(50, h.eval()) def testHandleAndValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle and a value. a = constant_op.constant(10) b = constant_op.constant(5) @@ -70,7 +70,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(500, v) def testHandleCond(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle and a value a = constant_op.constant(10) b = constant_op.constant(5) @@ -90,7 +90,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(5000, result) def testHandleForLoop(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize a handle. a = constant_op.constant(0) h = session_ops.get_session_handle(a) @@ -107,7 +107,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(100, h.eval()) def testHandleWhileLoop(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize a handle. a = constant_op.constant(0) h = session_ops.get_session_handle(a) @@ -127,7 +127,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(101, h.eval()) def testHandleMover(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle. a = constant_op.constant(10) b = constant_op.constant(5) @@ -148,7 +148,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(100, sess.run(y, feed_dict={f: h.handle})) def testHandleDelete(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle. a = constant_op.constant(10) b = constant_op.constant(5) @@ -157,7 +157,7 @@ class SessionOpsTest(test.TestCase): sess.run(h).delete() def testHandleDeleteRaw(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Return a handle. a = constant_op.constant(10) b = constant_op.constant(5) @@ -171,7 +171,7 @@ class SessionOpsTest(test.TestCase): sess.run(x, feed_dict={f: raw_h}) def testMultiDevices(self): - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device(test.gpu_device_name()): a = constant_op.constant(1.0) a_handle = sess.run(session_ops.get_session_handle(a)) @@ -189,7 +189,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(3.0, c_handle.eval()) def testHandleGC(self): - with self.test_session() as sess: + with self.cached_session() as sess: # initial values live on CPU with ops.device("/cpu:0"): one = constant_op.constant(1, dtype=dtypes.float32) @@ -213,7 +213,7 @@ class SessionOpsTest(test.TestCase): add_h2: x_handle.handle}) def testHandlePlacement(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(1.0) a_handle_op = session_ops.get_session_handle(a) b = constant_op.constant(2.0) @@ -233,7 +233,7 @@ class SessionOpsTest(test.TestCase): self.assertEqual(3.0, c_handle.eval()) def testFeedOneHandleDirectly(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(10.0) b = constant_op.constant(5.0) c = math_ops.multiply(a, b) @@ -244,7 +244,7 @@ class SessionOpsTest(test.TestCase): self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c})) def testDirectHandleFeedOverlappingWithFetches(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(10.0) b = constant_op.constant(5.0) c = math_ops.multiply(a, b) @@ -270,7 +270,7 @@ class SessionOpsTest(test.TestCase): self.assertAllClose(50.0, d_val) def testFeedTwoHandlesDirectly(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(10.0) b = constant_op.constant(5.0) c = math_ops.multiply(a, b) @@ -284,7 +284,7 @@ class SessionOpsTest(test.TestCase): self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c})) def testFeedHandleToVariableDirectly(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = variables.Variable(12.0) inc_a = state_ops.assign_add(a, 2.0) b = math_ops.add(a, 5.0) |