diff options
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r-- | tensorflow/tensorflow.bzl | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7fa7e4a91d..0e5b39af10 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -12,6 +12,7 @@ load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", "tf_sycl_tests_tags", + "tf_additional_xla_deps_py", ) load( "@local_config_cuda//cuda:build_defs.bzl", @@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs): **kwargs) def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], - tags=[], shard_count=1, additional_deps=[], flaky=0): + tags=[], shard_count=1, additional_deps=[], flaky=0, + xla_enabled=False): + if xla_enabled: + additional_deps += tf_additional_xla_deps_py() native.py_test( name=name, size=size, @@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], srcs_version="PY2AND3") def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0): + shard_count=1, additional_deps=[], tags=[], flaky=0, + xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() tf_py_test(name=name, size=size, @@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[], tags=test_tags, shard_count=shard_count, additional_deps=additional_deps, - flaky=flaky) + flaky=flaky, + xla_enabled=xla_enabled) def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0): + shard_count=1, additional_deps=[], tags=[], flaky=0, + xla_enabled=False): test_tags = tags + tf_sycl_tests_tags() tf_py_test(name=name, size=size, @@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[], tags=test_tags, shard_count=shard_count, additional_deps=additional_deps, - flaky=flaky) + flaky=flaky, + xla_enabled=xla_enabled) def py_tests(name, srcs, @@ -845,7 +853,8 @@ def py_tests(name, data=[], tags=[], shard_count=1, - prefix=""): + prefix="", + xla_enabled=False): for src in srcs: test_name = src.split("/")[-1].split(".")[0] if prefix: @@ -857,13 +866,15 @@ def py_tests(name, tags=tags, shard_count=shard_count, data=data, - additional_deps=additional_deps) + additional_deps=additional_deps, + xla_enabled=xla_enabled) def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[], - shard_count=1, tags=[], prefix=""): + shard_count=1, tags=[], prefix="", xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps, - data=data, tags=test_tags, shard_count=shard_count,prefix=prefix) + data=data, tags=test_tags, shard_count=shard_count,prefix=prefix, + xla_enabled=xla_enabled) # Creates a genrule named <name> for running tools/proto_text's generator to # make the proto_text functions, for the protos passed in <srcs>. |