aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorflow.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r--tensorflow/tensorflow.bzl29
1 files changed, 20 insertions, 9 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 7fa7e4a91d..0e5b39af10 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -12,6 +12,7 @@ load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
"tf_sycl_tests_tags",
+ "tf_additional_xla_deps_py",
)
load(
"@local_config_cuda//cuda:build_defs.bzl",
@@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs):
**kwargs)
def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- tags=[], shard_count=1, additional_deps=[], flaky=0):
+ tags=[], shard_count=1, additional_deps=[], flaky=0,
+ xla_enabled=False):
+ if xla_enabled:
+ additional_deps += tf_additional_xla_deps_py()
native.py_test(
name=name,
size=size,
@@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
srcs_version="PY2AND3")
def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- shard_count=1, additional_deps=[], tags=[], flaky=0):
+ shard_count=1, additional_deps=[], tags=[], flaky=0,
+ xla_enabled=False):
test_tags = tags + tf_cuda_tests_tags()
tf_py_test(name=name,
size=size,
@@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
tags=test_tags,
shard_count=shard_count,
additional_deps=additional_deps,
- flaky=flaky)
+ flaky=flaky,
+ xla_enabled=xla_enabled)
def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- shard_count=1, additional_deps=[], tags=[], flaky=0):
+ shard_count=1, additional_deps=[], tags=[], flaky=0,
+ xla_enabled=False):
test_tags = tags + tf_sycl_tests_tags()
tf_py_test(name=name,
size=size,
@@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
tags=test_tags,
shard_count=shard_count,
additional_deps=additional_deps,
- flaky=flaky)
+ flaky=flaky,
+ xla_enabled=xla_enabled)
def py_tests(name,
srcs,
@@ -845,7 +853,8 @@ def py_tests(name,
data=[],
tags=[],
shard_count=1,
- prefix=""):
+ prefix="",
+ xla_enabled=False):
for src in srcs:
test_name = src.split("/")[-1].split(".")[0]
if prefix:
@@ -857,13 +866,15 @@ def py_tests(name,
tags=tags,
shard_count=shard_count,
data=data,
- additional_deps=additional_deps)
+ additional_deps=additional_deps,
+ xla_enabled=xla_enabled)
def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
- shard_count=1, tags=[], prefix=""):
+ shard_count=1, tags=[], prefix="", xla_enabled=False):
test_tags = tags + tf_cuda_tests_tags()
py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
- data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
+ data=data, tags=test_tags, shard_count=shard_count,prefix=prefix,
+ xla_enabled=xla_enabled)
# Creates a genrule named <name> for running tools/proto_text's generator to
# make the proto_text functions, for the protos passed in <srcs>.