diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index d250af9037..09aa30a20b 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -42,7 +42,7 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.util import nest - +from tensorflow.python.framework import test_util class Plus1RNNCell(rnn_lib.RNNCell): """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" @@ -2209,9 +2209,10 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return # Test requires access to a GPU run_metadata = self._execute_rnn_on( - rnn_device="/cpu:0", cell_device="/gpu:0") + rnn_device="/cpu:0", cell_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats @@ -2233,9 +2234,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return # Test requires access to a GPU run_metadata = self._execute_rnn_on( - rnn_device="/cpu:0", cell_device="/cpu:0", input_device="/gpu:0") + rnn_device="/cpu:0", cell_device="/cpu:0", + input_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats @@ -2250,9 +2253,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): if not test.is_gpu_available(): return # Test requires access to a GPU - run_metadata = self._execute_rnn_on(input_device="/gpu:0") + run_metadata = self._execute_rnn_on( + input_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats |