aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/build_defs.bzl
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-16 16:14:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 16:19:51 -0700
commitdf609c9de4ea0cae0fb1d41893b8071d67bd6bb2 (patch)
tree0db6e3a90573fd4aa357b23b2461377bc3a1cb56 /tensorflow/compiler/xla/tests/build_defs.bzl
parent226b193e709bd513400f2d74020b02d90cd3a0a0 (diff)
Allow explict "gpu" in backends without failing if CUDA is not enabled.
PiperOrigin-RevId: 159289583
Diffstat (limited to 'tensorflow/compiler/xla/tests/build_defs.bzl')
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl32
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 1f61743451..50edd8ea5b 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -2,11 +2,25 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
-def all_backends():
+all_backends = ["cpu", "cpu_parallel", "gpu"]
+
+def filter_backends(backends):
+ """Removes "gpu" from a backend list if CUDA is not enabled.
+
+ This allows us to simply hardcode lists including "gpu" here and in the
+ BUILD file, without causing failures when CUDA isn't enabled.'
+
+ Args:
+ backends: A list of backends to filter.
+
+ Returns:
+ The filtered list of backends.
+ """
if cuda_is_configured():
- return ["cpu", "cpu_parallel", "gpu"]
+ return backends
else:
- return ["cpu", "cpu_parallel"]
+ return [backend for backend in backends if backend != "gpu"]
+
def xla_test(name,
srcs,
@@ -81,7 +95,7 @@ def xla_test(name,
"""
test_names = []
if not backends:
- backends = all_backends()
+ backends = all_backends
native.cc_library(
name="%s_lib" % name,
@@ -91,7 +105,7 @@ def xla_test(name,
deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
- for backend in backends:
+ for backend in filter_backends(backends):
test_name = "%s_%s" % (name, backend)
this_backend_tags = ["xla_%s" % backend]
this_backend_copts = []
@@ -127,16 +141,16 @@ def xla_test(name,
def generate_backend_suites(backends=[]):
if not backends:
- backends = all_backends()
- for backend in backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
native.test_suite(name="%s_tests" % backend,
tags = ["xla_%s" % backend])
def generate_backend_test_macros(backends=[]):
if not backends:
- backends = all_backends()
- for backend in backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
native.cc_library(
name="test_macros_%s" % backend,
testonly = True,