aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-07 11:18:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 11:22:37 -0800
commitb5c8cd65feb2614e739a83136e3d333b51a6c2f8 (patch)
treea2f9511f302338c56e808e9121c411c42e24862a /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parente81b739873577dcb828dcb79cc1708eb4b8ae91c (diff)
Fix control flow test to not use session after it's gone out of scope.
This somehow works currently, but breaks with the C API enabled. PiperOrigin-RevId: 178268847
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index ad02a9e58c..20eb923e72 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1450,7 +1450,8 @@ class ControlFlowTest(test.TestCase):
gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
) else "/device:GPU:0"
- with self.test_session(graph=ops.Graph()) as sess:
+ graph = ops.Graph()
+ with graph.as_default():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
@@ -1461,7 +1462,8 @@ class ControlFlowTest(test.TestCase):
loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
r = gradients_impl.gradients(
loop, v, colocate_gradients_with_ops=colocate)[0]
- r_ops = r.graph.get_operations()
+
+ r_ops = graph.get_operations()
r_devices = [(op.name, op.device) for op in r_ops]
self.assertTrue(any("Square" in op.name for op in r_ops))
@@ -1475,7 +1477,9 @@ class ControlFlowTest(test.TestCase):
self.assertTrue(gpu_dev_name in dev)
else:
self.assertFalse(gpu_dev_name in dev)
- self.assertAllClose(1024.0, sess.run(r))
+
+ with self.test_session(graph=graph) as sess:
+ self.assertAllClose(1024.0, sess.run(r))
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)