diff options
Diffstat (limited to 'tensorflow/compiler/tests/build_defs.bzl')
-rw-r--r-- | tensorflow/compiler/tests/build_defs.bzl | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 820db13d0b..0bde616521 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,12 +1,14 @@ """Build rules for Tensorflow/XLA testing.""" load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") +load("//tensorflow/compiler/tests:plugin.bzl", "plugins") def all_backends(): + b = ["cpu"] + plugins.keys() if cuda_is_configured(): - return ["cpu", "gpu"] + return b + ["gpu"] else: - return ["cpu"] + return b def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, disabled_backends=None, **kwargs): @@ -53,6 +55,10 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, backend_args += ["--test_device=XLA_GPU", "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] backend_tags += ["requires-gpu-sm35"] + elif backend in plugins: + backend_args += ["--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"]] + backend_tags += plugins[backend]["tags"] else: fail("Unknown backend {}".format(backend)) |