aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-01-27 17:37:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-27 17:46:06 -0800
commitce2a102760ad7c3503db484a4c0cdd97b5543dfc (patch)
tree937d61ee560bb9b7a8798106fe648b599088521d
parent9d7899f99267727e65e0a8183f30cab24bef5536 (diff)
Add xla target for tensorflow tests that request them.
Change: 145856327
-rwxr-xr-xconfigure4
-rw-r--r--tensorflow/core/platform/default/build_config.bzl13
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl17
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/tensorflow.bzl29
-rw-r--r--tensorflow/tools/pip_package/BUILD2
6 files changed, 41 insertions, 26 deletions
diff --git a/configure b/configure
index 6cd5c2f3db..2d8d85b021 100755
--- a/configure
+++ b/configure
@@ -175,10 +175,10 @@ done
if [ "$TF_ENABLE_XLA" == "1" ]; then
# Update Bazel build configuration.
- perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
+ sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
else
# Update Bazel build configuration.
- perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
+ sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
fi
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index ebf835d110..56d4f6ff58 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -7,7 +7,6 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile")
# configure may change the following lines
WITH_GCP_SUPPORT = False
WITH_HDFS_SUPPORT = False
-WITH_XLA_SUPPORT = False
WITH_JEMALLOC = True
# Appends a suffix to a list of deps.
@@ -242,15 +241,3 @@ def tf_additional_cloud_kernel_deps():
#if WITH_GCP_SUPPORT:
# deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"])
return deps
-
-def tf_additional_plugin_deps():
- deps = []
- if WITH_XLA_SUPPORT:
- deps.append("//tensorflow/compiler/jit")
- return deps
-
-def tf_additional_license_deps():
- licenses = []
- if WITH_XLA_SUPPORT:
- licenses.append("@llvm//:LICENSE.TXT")
- return licenses
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 2fa2726bde..23a7b9065a 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -2,8 +2,25 @@
# The functions in this file might be referred by tensorflow.bzl. They have to
# be separate to avoid cyclic references.
+WITH_XLA_SUPPORT = False
+
def tf_cuda_tests_tags():
return ["local"]
def tf_sycl_tests_tags():
return ["local"]
+
+def tf_additional_plugin_deps():
+ deps = []
+ if WITH_XLA_SUPPORT:
+ deps.append("//tensorflow/compiler/jit")
+ return deps
+
+def tf_additional_xla_deps_py():
+ return []
+
+def tf_additional_license_deps():
+ licenses = []
+ if WITH_XLA_SUPPORT:
+ licenses.append("@llvm//:LICENSE.TXT")
+ return licenses
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1834ce570e..2befe43be6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -23,7 +23,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
py_library(
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 7fa7e4a91d..0e5b39af10 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -12,6 +12,7 @@ load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
"tf_sycl_tests_tags",
+ "tf_additional_xla_deps_py",
)
load(
"@local_config_cuda//cuda:build_defs.bzl",
@@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs):
**kwargs)
def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- tags=[], shard_count=1, additional_deps=[], flaky=0):
+ tags=[], shard_count=1, additional_deps=[], flaky=0,
+ xla_enabled=False):
+ if xla_enabled:
+ additional_deps += tf_additional_xla_deps_py()
native.py_test(
name=name,
size=size,
@@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
srcs_version="PY2AND3")
def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- shard_count=1, additional_deps=[], tags=[], flaky=0):
+ shard_count=1, additional_deps=[], tags=[], flaky=0,
+ xla_enabled=False):
test_tags = tags + tf_cuda_tests_tags()
tf_py_test(name=name,
size=size,
@@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
tags=test_tags,
shard_count=shard_count,
additional_deps=additional_deps,
- flaky=flaky)
+ flaky=flaky,
+ xla_enabled=xla_enabled)
def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
- shard_count=1, additional_deps=[], tags=[], flaky=0):
+ shard_count=1, additional_deps=[], tags=[], flaky=0,
+ xla_enabled=False):
test_tags = tags + tf_sycl_tests_tags()
tf_py_test(name=name,
size=size,
@@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
tags=test_tags,
shard_count=shard_count,
additional_deps=additional_deps,
- flaky=flaky)
+ flaky=flaky,
+ xla_enabled=xla_enabled)
def py_tests(name,
srcs,
@@ -845,7 +853,8 @@ def py_tests(name,
data=[],
tags=[],
shard_count=1,
- prefix=""):
+ prefix="",
+ xla_enabled=False):
for src in srcs:
test_name = src.split("/")[-1].split(".")[0]
if prefix:
@@ -857,13 +866,15 @@ def py_tests(name,
tags=tags,
shard_count=shard_count,
data=data,
- additional_deps=additional_deps)
+ additional_deps=additional_deps,
+ xla_enabled=xla_enabled)
def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
- shard_count=1, tags=[], prefix=""):
+ shard_count=1, tags=[], prefix="", xla_enabled=False):
test_tags = tags + tf_cuda_tests_tags()
py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
- data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
+ data=data, tags=test_tags, shard_count=shard_count,prefix=prefix,
+ xla_enabled=xla_enabled)
# Creates a genrule named <name> for running tools/proto_text's generator to
# make the proto_text functions, for the protos passed in <srcs>.
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 0ffbec8b3c..85a8b79f85 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -4,7 +4,7 @@
package(default_visibility = ["//visibility:private"])
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
# This returns a list of headers of all public header libraries (e.g.,
# framework, lib), and all of the transitive dependencies of those