aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index d250af9037..09aa30a20b 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -42,7 +42,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest
-
+from tensorflow.python.framework import test_util
class Plus1RNNCell(rnn_lib.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@@ -2209,9 +2209,10 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
return # Test requires access to a GPU
run_metadata = self._execute_rnn_on(
- rnn_device="/cpu:0", cell_device="/gpu:0")
+ rnn_device="/cpu:0", cell_device=test_util.gpu_device_name())
step_stats = run_metadata.step_stats
- ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+ ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+ ("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats
@@ -2233,9 +2234,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
return # Test requires access to a GPU
run_metadata = self._execute_rnn_on(
- rnn_device="/cpu:0", cell_device="/cpu:0", input_device="/gpu:0")
+ rnn_device="/cpu:0", cell_device="/cpu:0",
+ input_device=test_util.gpu_device_name())
step_stats = run_metadata.step_stats
- ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+ ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+ ("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats
@@ -2250,9 +2253,11 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase):
if not test.is_gpu_available():
return # Test requires access to a GPU
- run_metadata = self._execute_rnn_on(input_device="/gpu:0")
+ run_metadata = self._execute_rnn_on(
+ input_device=test_util.gpu_device_name())
step_stats = run_metadata.step_stats
- ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+ ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+ ("sycl" in step_stats.dev_stats[0].device)) else 1
gpu_stats = step_stats.dev_stats[ix].node_stats
cpu_stats = step_stats.dev_stats[1 - ix].node_stats