aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/build_defs.bzl
blob: 7fb8e0a26d594a9a0e5b07f676c2c2ce5c1d8c2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Build rules for Tensorflow/XLA testing."""

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")

def all_backends():
  """Returns the list of XLA backends to generate tests for.

  Returns:
    ["cpu", "gpu"] when cuda_is_configured() is true, otherwise ["cpu"].
  """
  # CPU is always included; GPU is added only for CUDA-configured builds.
  available = ["cpu"]
  if cuda_is_configured():
    available.append("gpu")
  return available

def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
                   backends=None, **kwargs):
  """Generates one py_test target per XLA backend, plus a covering test suite.

  For every requested backend a py_test() named `<name>_<backend>` is created.
  A test_suite() named `name` is also created that lists all of the generated
  per-backend tests.

  Example: the rule below generates the tests foo_test_cpu and foo_test_gpu
  (when both backends are selected) and a test suite foo_test covering them.

  tf_xla_py_test(
      name="foo_test",
      srcs=["foo_test.py"],
      deps=[...],
  )

  Args:
    name: Name of the generated test suite, and prefix of the per-backend
      test targets.
    srcs: Sources for the generated tests.
    deps: Dependencies of the generated tests.
    tags: Tags to apply to the generated tests. A "tf_xla_<backend>" tag is
      always added on top of these.
    data: Data dependencies of the generated tests.
    main: Same as py_test's main attribute. Defaults to "<name>.py" when None.
    backends: A list of backends to test. Supported values are "cpu" and
      "gpu"; any other value fails the build. If None, defaults to
      all_backends().
    **kwargs: Keyword arguments passed on to each generated py_test() rule.
  """
  if backends == None:
    backends = all_backends()

  # Both supported backends are exercised with the same set of data types.
  types_flag = "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"

  generated_tests = []
  for backend in backends:
    # Per-backend additions; no backend currently contributes extra deps or
    # data, so those two stay empty.
    extra_deps = []
    extra_data = []
    if backend == "cpu":
      device_flag = "--test_device=XLA_CPU"
      extra_tags = []
    elif backend == "gpu":
      device_flag = "--test_device=XLA_GPU"
      extra_tags = ["requires-gpu-sm35"]
    else:
      fail("Unknown backend {}".format(backend))

    target = "{}_{}".format(name, backend)
    # The "tf_xla_<backend>" tag is what generate_backend_suites() filters on.
    marker_tags = ["tf_xla_{}".format(backend)] + extra_tags

    native.py_test(
        name=target,
        srcs=srcs,
        srcs_version="PY2AND3",
        args=[device_flag, types_flag],
        main=main if main != None else "{}.py".format(name),
        data=data + extra_data,
        deps=deps + extra_deps,
        tags=tags + marker_tags,
        **kwargs
    )
    generated_tests.append(target)

  native.test_suite(name=name, tests=generated_tests)

def generate_backend_suites(backends=[]):
  """Generates one tag-filtered test_suite per backend.

  Each suite is named "<backend>_tests" and is defined only by the tag
  "tf_xla_<backend>", which is the tag tf_xla_py_test() attaches to the tests
  it generates for that backend.

  Args:
    backends: A list of backend names to create suites for. If empty, defaults
      to all_backends().
  """
  selected = backends if backends else all_backends()
  for backend in selected:
    native.test_suite(
        name="%s_tests" % backend,
        tags=["tf_xla_%s" % backend],
    )