diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index 491d87f62d..3496b355b4 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -65,7 +65,9 @@ class GatherTreeTest(test.TestCase): _ = beams.eval() def testBadParentValuesOnGPU(self): - if not test.is_gpu_available(): + # Only want to run this test on CUDA devices, as gather_tree is not + # registered for SYCL devices. + if not test.is_gpu_available(cuda_only=True): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 |