diff options
author | 2017-06-16 16:14:33 -0700 | |
---|---|---|
committer | 2017-06-16 16:19:51 -0700 | |
commit | df609c9de4ea0cae0fb1d41893b8071d67bd6bb2 (patch) | |
tree | 0db6e3a90573fd4aa357b23b2461377bc3a1cb56 | |
parent | 226b193e709bd513400f2d74020b02d90cd3a0a0 (diff) |
Allow explict "gpu" in backends without failing if CUDA is not enabled.
PiperOrigin-RevId: 159289583
-rw-r--r-- | tensorflow/compiler/xla/tests/build_defs.bzl | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 1f61743451..50edd8ea5b 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -2,11 +2,25 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") -def all_backends(): +all_backends = ["cpu", "cpu_parallel", "gpu"] + +def filter_backends(backends): + """Removes "gpu" from a backend list if CUDA is not enabled. + + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' + + Args: + backends: A list of backends to filter. + + Returns: + The filtered list of backends. + """ if cuda_is_configured(): - return ["cpu", "cpu_parallel", "gpu"] + return backends else: - return ["cpu", "cpu_parallel"] + return [backend for backend in backends if backend != "gpu"] + def xla_test(name, srcs, @@ -81,7 +95,7 @@ def xla_test(name, """ test_names = [] if not backends: - backends = all_backends() + backends = all_backends native.cc_library( name="%s_lib" % name, @@ -91,7 +105,7 @@ def xla_test(name, deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - for backend in backends: + for backend in filter_backends(backends): test_name = "%s_%s" % (name, backend) this_backend_tags = ["xla_%s" % backend] this_backend_copts = [] @@ -127,16 +141,16 @@ def xla_test(name, def generate_backend_suites(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.test_suite(name="%s_tests" % backend, tags = ["xla_%s" % backend]) def generate_backend_test_macros(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.cc_library( name="test_macros_%s" % backend, testonly = True, |