diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 54b26057e2..ec681d613f 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -255,7 +255,7 @@ class ControlFlowTest(test.TestCase): enter_one = control_flow_ops.enter(one, "foo", True) enter_n = control_flow_ops.enter(n, "foo", True) - with ops.device("/gpu:0"): + with ops.device(test.gpu_device_name()): merge_i = control_flow_ops.merge([enter_i, enter_i])[0] less_op = math_ops.less(merge_i, enter_n) @@ -289,7 +289,7 @@ class ControlFlowTest(test.TestCase): add_i = math_ops.add(switch_i[1], enter_one) - with ops.device("/gpu:0"): + with ops.device(test.gpu_device_name()): next_i = control_flow_ops.next_iteration(add_i) merge_i.op._update_input(1, next_i) @@ -567,7 +567,7 @@ class ControlFlowTest(test.TestCase): def testCondRecvIdentity(self): # Make sure the switch identity is not removed by optimization. with session.Session(config=opt_cfg()) as sess: - with ops.device("/gpu:0"): + with ops.device(test.gpu_device_name()): pred = constant_op.constant(True) def fn1(): @@ -1341,12 +1341,15 @@ class ControlFlowTest(test.TestCase): self.assertEqual(45, rx.eval()) def _testWhileGrad_ColocateGradients(self, colocate): + gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0" + gpu_short_name = gpu_dev_name.split('/')[-1] + with self.test_session(graph=ops.Graph()) as sess: v = constant_op.constant(2.0, name="v") c = lambda v: math_ops.less(v, 100.0) def b(x): - with ops.device("/gpu:0"): + with ops.device(gpu_dev_name): return math_ops.square(x) loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) @@ -1360,12 +1363,12 @@ class ControlFlowTest(test.TestCase): for (name, dev) in r_devices: if not colocate and name.endswith("Square"): # Only forward graph contain gpu in Square device - self.assertTrue("gpu:0" in dev) + self.assertTrue(gpu_short_name in dev) elif colocate and "Square" in name: # Forward and backward graphs contain gpu in Square/Square_grad devices - self.assertTrue("gpu:0" in dev) + self.assertTrue(gpu_short_name in dev) else: - self.assertFalse("gpu:0" in dev) + self.assertFalse(gpu_short_name in dev) self.assertAllClose(1024.0, sess.run(r)) def testWhileGrad_ColocateGradients(self): @@ -2566,7 +2569,7 @@ class AssertTest(test.TestCase): def testGuardedAssertDoesNotCopyWhenTrue(self): with self.test_session(use_gpu=True) as sess: - with ops.device("/gpu:0"): + with ops.device(test.gpu_device_name()): value = constant_op.constant(1.0) with ops.device("/cpu:0"): true = constant_op.constant(True) |