diff options
Diffstat (limited to 'tensorflow/compiler/tests/build_defs.bzl')
-rw-r--r-- | tensorflow/compiler/tests/build_defs.bzl | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl new file mode 100644 index 0000000000..7fb8e0a26d --- /dev/null +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -0,0 +1,78 @@ +"""Build rules for Tensorflow/XLA testing.""" + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") + +def all_backends(): + if cuda_is_configured(): + return ["cpu", "gpu"] + else: + return ["cpu"] + +def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, + backends=None, **kwargs): + """Generates py_test targets, one per XLA backend. + + This rule generates py_test() targets named name_backend, for each backend + in all_backends(). The rule also generates a test suite with named `name` that + tests all backends for the test. + + For example, the following rule generates test cases foo_test_cpu, + foo_test_gpu, and a test suite name foo_test that tests both. + tf_xla_py_test( + name="foo_test", + srcs="foo_test.py", + deps=[...], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + tags: Tags to apply to the generated targets. + data: Data dependencies of the target. + main: Same as py_test's main attribute. + backends: A list of backends to test. Supported values include "cpu" and + "gpu". If not specified, defaults to all backends. + **kwargs: keyword arguments passed onto the generated py_test() rules. + """ + if backends == None: + backends = all_backends() + + test_names = [] + for backend in backends: + test_name = "{}_{}".format(name, backend) + backend_tags = ["tf_xla_{}".format(backend)] + backend_args = [] + backend_deps = [] + backend_data = [] + if backend == "cpu": + backend_args += ["--test_device=XLA_CPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + elif backend == "gpu": + backend_args += ["--test_device=XLA_GPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + backend_tags += ["requires-gpu-sm35"] + else: + fail("Unknown backend {}".format(backend)) + + native.py_test( + name=test_name, + srcs=srcs, + srcs_version="PY2AND3", + args=backend_args, + main="{}.py".format(name) if main == None else main, + data=data + backend_data, + deps=deps + backend_deps, + tags=tags + backend_tags, + **kwargs + ) + test_names.append(test_name) + native.test_suite(name=name, tests=test_names) + +def generate_backend_suites(backends=[]): + """Generates per-backend test_suites that run all tests for a backend.""" + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) + |