aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
index 491d87f62d..3496b355b4 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
@@ -65,7 +65,9 @@ class GatherTreeTest(test.TestCase):
_ = beams.eval()
def testBadParentValuesOnGPU(self):
- if not test.is_gpu_available():
+ # Only want to run this test on CUDA devices, as gather_tree is not
+ # registered for SYCL devices.
+ if not test.is_gpu_available(cuda_only=True):
return
# (max_time = 4, batch_size = 1, beams = 3)
# bad parent in beam 1 time 1; appears as a negative index at time 0