aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 54b26057e2..ec681d613f 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -255,7 +255,7 @@ class ControlFlowTest(test.TestCase):
enter_one = control_flow_ops.enter(one, "foo", True)
enter_n = control_flow_ops.enter(n, "foo", True)
- with ops.device("/gpu:0"):
+ with ops.device(test.gpu_device_name()):
merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
less_op = math_ops.less(merge_i, enter_n)
@@ -289,7 +289,7 @@ class ControlFlowTest(test.TestCase):
add_i = math_ops.add(switch_i[1], enter_one)
- with ops.device("/gpu:0"):
+ with ops.device(test.gpu_device_name()):
next_i = control_flow_ops.next_iteration(add_i)
merge_i.op._update_input(1, next_i)
@@ -567,7 +567,7 @@ class ControlFlowTest(test.TestCase):
def testCondRecvIdentity(self):
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
- with ops.device("/gpu:0"):
+ with ops.device(test.gpu_device_name()):
pred = constant_op.constant(True)
def fn1():
@@ -1341,12 +1341,15 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(45, rx.eval())
def _testWhileGrad_ColocateGradients(self, colocate):
+ gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0"
+ gpu_short_name = gpu_dev_name.split('/')[-1]
+
with self.test_session(graph=ops.Graph()) as sess:
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
def b(x):
- with ops.device("/gpu:0"):
+ with ops.device(gpu_dev_name):
return math_ops.square(x)
loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
@@ -1360,12 +1363,12 @@ class ControlFlowTest(test.TestCase):
for (name, dev) in r_devices:
if not colocate and name.endswith("Square"):
# Only forward graph contain gpu in Square device
- self.assertTrue("gpu:0" in dev)
+ self.assertTrue(gpu_short_name in dev)
elif colocate and "Square" in name:
# Forward and backward graphs contain gpu in Square/Square_grad devices
- self.assertTrue("gpu:0" in dev)
+ self.assertTrue(gpu_short_name in dev)
else:
- self.assertFalse("gpu:0" in dev)
+ self.assertFalse(gpu_short_name in dev)
self.assertAllClose(1024.0, sess.run(r))
def testWhileGrad_ColocateGradients(self):
@@ -2566,7 +2569,7 @@ class AssertTest(test.TestCase):
def testGuardedAssertDoesNotCopyWhenTrue(self):
with self.test_session(use_gpu=True) as sess:
- with ops.device("/gpu:0"):
+ with ops.device(test.gpu_device_name()):
value = constant_op.constant(1.0)
with ops.device("/cpu:0"):
true = constant_op.constant(True)