aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/build_defs.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/build_defs.bzl')
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 820db13d0b..0bde616521 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -1,12 +1,14 @@
"""Build rules for Tensorflow/XLA testing."""
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
+load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
def all_backends():
+ b = ["cpu"] + plugins.keys()
if cuda_is_configured():
- return ["cpu", "gpu"]
+ return b + ["gpu"]
else:
- return ["cpu"]
+ return b
def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
disabled_backends=None, **kwargs):
@@ -53,6 +55,10 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
backend_args += ["--test_device=XLA_GPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
backend_tags += ["requires-gpu-sm35"]
+ elif backend in plugins:
+ backend_args += ["--test_device=" + plugins[backend]["device"],
+ "--types=" + plugins[backend]["types"]]
+ backend_tags += plugins[backend]["tags"]
else:
fail("Unknown backend {}".format(backend))