aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorflow.bzl
diff options
context:
space:
mode:
authorGravatar Austin Anderson <angerson@google.com>2017-12-01 17:09:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 17:13:56 -0800
commit7248c3ec2c87648fec732e17f3e749d12d113abe (patch)
tree9a182c21000042d9d53b25fb3c201353ad93c9ca /tensorflow/tensorflow.bzl
parente30b0babce133631b19de1fd7bacc84c884d6f55 (diff)
Small reformatting of tensorflow.bzl
PiperOrigin-RevId: 177661127
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r--tensorflow/tensorflow.bzl275
1 files changed, 115 insertions, 160 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 709a2d46e1..0015eb0094 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -1,6 +1,5 @@
# -*- Python -*-
-
# Return the options to use for a C++ library or binary build.
# Uses the ":optmode" config_setting to pick the options.
load(
@@ -8,38 +7,35 @@ load(
"tf_cuda_tests_tags",
"tf_sycl_tests_tags",
"tf_additional_xla_deps_py",
- "if_static",)
+ "if_static",
+)
load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda",
- "cuda_default_copts",)
-
+ "cuda_default_copts",
+)
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",)
-
+ "if_mkl",
+)
def register_extension_info(**kwargs):
pass
-
# Given a source file, generate a test name.
# i.e. "common_runtime/direct_session_test.cc" becomes
# "common_runtime_direct_session_test"
def src_to_test_name(src):
return src.replace("/", "_").split(".")[0]
-
def full_path(relative_paths):
return [PACKAGE_NAME + "/" + relative for relative in relative_paths]
-
# List of proto files for android builds
def tf_android_core_proto_sources(core_proto_sources_relative):
return [
"//tensorflow/core:" + p for p in core_proto_sources_relative
]
-
# Returns the list of pb.h and proto.h headers that are generated for
# tf_android_core_proto_sources().
def tf_android_core_proto_headers(core_proto_sources_relative):
@@ -51,13 +47,11 @@ def tf_android_core_proto_headers(core_proto_sources_relative):
for p in core_proto_sources_relative
])
-
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
return str(Label(dep))
-
def if_android_x86(a):
return select({
clean_dep("//tensorflow:android_x86"): a,
@@ -65,35 +59,30 @@ def if_android_x86(a):
"//conditions:default": [],
})
-
def if_android_arm(a):
return select({
clean_dep("//tensorflow:android_arm"): a,
"//conditions:default": [],
})
-
def if_android_arm64(a):
return select({
clean_dep("//tensorflow:android_arm64"): a,
"//conditions:default": [],
})
-
def if_android_mips(a):
return select({
clean_dep("//tensorflow:android_mips"): a,
"//conditions:default": [],
})
-
def if_not_android(a):
return select({
clean_dep("//tensorflow:android"): [],
"//conditions:default": a,
})
-
def if_not_android_mips_and_mips64(a):
return select({
clean_dep("//tensorflow:android_mips"): [],
@@ -101,21 +90,18 @@ def if_not_android_mips_and_mips64(a):
"//conditions:default": a,
})
-
def if_android(a):
return select({
clean_dep("//tensorflow:android"): a,
"//conditions:default": [],
})
-
def if_ios(a):
return select({
clean_dep("//tensorflow:ios"): a,
"//conditions:default": [],
})
-
def if_mobile(a):
return select({
clean_dep("//tensorflow:android"): a,
@@ -123,7 +109,6 @@ def if_mobile(a):
"//conditions:default": [],
})
-
def if_not_mobile(a):
return select({
clean_dep("//tensorflow:android"): [],
@@ -131,7 +116,6 @@ def if_not_mobile(a):
"//conditions:default": a,
})
-
def if_not_windows(a):
return select({
clean_dep("//tensorflow:windows"): [],
@@ -139,7 +123,6 @@ def if_not_windows(a):
"//conditions:default": a,
})
-
def if_linux_x86_64(a):
return select({
clean_dep("//tensorflow:linux_x86_64"): a,
@@ -161,8 +144,10 @@ WIN_COPTS = [
"/DTENSORFLOW_USE_EIGEN_THREADPOOL",
"/DEIGEN_AVOID_STL_ARRAY",
"/Iexternal/gemmlowp",
- "/wd4018", # -Wno-sign-compare
- "/U_HAS_EXCEPTIONS", "/D_HAS_EXCEPTIONS=1", "/EHsc", # -fno-exceptions
+ "/wd4018", # -Wno-sign-compare
+ "/U_HAS_EXCEPTIONS",
+ "/D_HAS_EXCEPTIONS=1",
+ "/EHsc", # -fno-exceptions
"/DNOGDI",
]
@@ -200,7 +185,6 @@ def tf_copts(android_optimization_level_override="-O2"):
"//conditions:default": ["-pthread"]
}))
-
def tf_opts_nortti_if_android():
return if_android([
"-fno-rtti",
@@ -208,10 +192,8 @@ def tf_opts_nortti_if_android():
"-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER",
])
-
# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
-
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names, deps=None):
@@ -229,13 +211,11 @@ def tf_gen_op_libs(op_lib_names, deps=None):
alwayslink=1,
linkstatic=1,)
-
def _make_search_paths(prefix, levels_to_root):
return ",".join(
["-rpath,%s/%s" % (prefix, "/".join([".."] * search_level))
for search_level in range(levels_to_root + 1)])
-
def _rpath_linkopts(name):
# Search parent directories up to the TensorFlow root directory for shared
# object dependencies, even if this op shared object is deeply nested
@@ -254,7 +234,6 @@ def _rpath_linkopts(name):
],
})
-
# Bazel-generated shared objects which must be linked into TensorFlow binaries
# to define symbols from //tensorflow/core:framework and //tensorflow/core:lib.
def tf_binary_additional_srcs():
@@ -264,7 +243,6 @@ def tf_binary_additional_srcs():
clean_dep("//tensorflow:libtensorflow_framework.so"),
])
-
def tf_cc_shared_object(
name,
srcs=[],
@@ -287,9 +265,9 @@ def tf_cc_shared_object(
**kwargs)
register_extension_info(
- extension_name="tf_cc_shared_object",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_cc_shared_object",
+ label_regex_for_dep = "{extension_name}",
+)
# Links in the framework shared object
# (//third_party/tensorflow:libtensorflow_framework.so) when not building
@@ -312,9 +290,9 @@ def tf_cc_binary(name,
**kwargs)
register_extension_info(
- extension_name="tf_cc_binary",
- label_regex_for_dep="{extension_name}.*")
-
+ extension_name = "tf_cc_binary",
+ label_regex_for_dep = "{extension_name}.*",
+)
def tf_gen_op_wrapper_cc(name,
out_ops_file,
@@ -368,7 +346,6 @@ def tf_gen_op_wrapper_cc(name,
"$(location :" + out_ops_file + ".cc) " + override_arg + " " +
str(include_internal_ops) + " " + api_def_args_str))
-
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate individual C++ .cc and .h
# files for each of the ops files mentioned, and then generate a
@@ -461,7 +438,6 @@ def tf_gen_op_wrappers_cc(name,
alwayslink=1,
visibility=[clean_dep("//tensorflow:internal")])
-
# Generates a Python library target wrapping the ops registered in "deps".
#
# Args:
@@ -554,7 +530,6 @@ def tf_gen_op_wrapper_py(name,
clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
],)
-
# Define a bazel macro that creates cc_test for tensorflow.
#
# Links in the framework shared object
@@ -597,9 +572,9 @@ def tf_cc_test(name,
**kwargs)
register_extension_info(
- extension_name="tf_cc_test",
- label_regex_for_dep="{extension_name}.*")
-
+ extension_name = "tf_cc_test",
+ label_regex_for_dep = "{extension_name}.*",
+)
# Part of the testing workflow requires a distinguishable name for the build
# rules that involve a GPU, even if otherwise identical to the base rule.
@@ -624,9 +599,9 @@ def tf_cc_test_gpu(name,
args=args)
register_extension_info(
- extension_name="tf_cc_test_gpu",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_cc_test_gpu",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_cuda_cc_test(name,
srcs=[],
@@ -668,9 +643,9 @@ def tf_cuda_cc_test(name,
args=args)
register_extension_info(
- extension_name="tf_cuda_cc_test",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_cuda_cc_test",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_cuda_only_cc_test(name,
srcs=[],
@@ -702,9 +677,9 @@ def tf_cuda_only_cc_test(name,
tags=tags + tf_cuda_tests_tags())
register_extension_info(
- extension_name="tf_cuda_only_cc_test",
- label_regex_for_dep="{extension_name}_gpu")
-
+ extension_name = "tf_cuda_only_cc_test",
+ label_regex_for_dep = "{extension_name}_gpu",
+)
# Create a cc_test for each of the tensorflow tests listed in "tests"
def tf_cc_tests(srcs,
@@ -728,7 +703,6 @@ def tf_cc_tests(srcs,
linkopts=linkopts,
nocopts=nocopts)
-
def tf_cc_test_mkl(srcs,
deps,
name="",
@@ -738,7 +712,6 @@ def tf_cc_test_mkl(srcs,
args=None):
if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))
-
def tf_cc_tests_gpu(srcs,
deps,
name="",
@@ -748,7 +721,6 @@ def tf_cc_tests_gpu(srcs,
args=None):
tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
-
def tf_cuda_cc_tests(srcs,
deps,
name="",
@@ -781,9 +753,9 @@ def tf_java_test(name,
**kwargs)
register_extension_info(
- extension_name="tf_java_test",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_java_test",
+ label_regex_for_dep = "{extension_name}",
+)
def _cuda_copts():
"""Gets the appropriate set of copts for (maybe) CUDA compilation.
@@ -803,10 +775,8 @@ def _cuda_copts():
]),
})
-
# Build defs for TensorFlow kernels
-
# When this target is built using --config=cuda, a cc_library is built
# that passes -DGOOGLE_CUDA=1 and '-x cuda', linking in additional
# libraries needed by GPU kernels.
@@ -830,9 +800,9 @@ def tf_gpu_kernel_library(srcs,
**kwargs)
register_extension_info(
- extension_name="tf_gpu_kernel_library",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_gpu_kernel_library",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_cuda_library(deps=None, cuda_deps=None, copts=None, **kwargs):
"""Generate a cc_library with a conditional set of CUDA dependencies.
@@ -866,10 +836,9 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=None, **kwargs):
**kwargs)
register_extension_info(
- extension_name="tf_cuda_library",
- label_regex_for_dep="{extension_name}")
-
-
+ extension_name = "tf_cuda_library",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_kernel_library(name,
prefix=None,
@@ -940,9 +909,9 @@ def tf_kernel_library(name,
**kwargs)
register_extension_info(
- extension_name="tf_kernel_library",
- label_regex_for_dep="{extension_name}(_gpu)?")
-
+ extension_name = "tf_kernel_library",
+ label_regex_for_dep = "{extension_name}(_gpu)?",
+)
def tf_mkl_kernel_library(name,
prefix=None,
@@ -981,9 +950,9 @@ def tf_mkl_kernel_library(name,
))
register_extension_info(
- extension_name="tf_mkl_kernel_library",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_mkl_kernel_library",
+ label_regex_for_dep = "{extension_name}",
+)
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
@@ -1017,44 +986,41 @@ def _py_wrap_cc_impl(ctx):
progress_message="SWIGing " + src.path)
return struct(files=depset(outputs))
-
_py_wrap_cc = rule(
- attrs={
- "srcs":
- attr.label_list(
- mandatory=True,
- allow_files=True,),
- "swig_includes":
- attr.label_list(
- cfg="data",
- allow_files=True,),
- "deps":
- attr.label_list(
- allow_files=True,
- providers=["cc"],),
- "toolchain_deps":
- attr.label_list(
- allow_files=True,),
- "module_name":
- attr.string(mandatory=True),
- "py_module_name":
- attr.string(mandatory=True),
- "_swig":
- attr.label(
- default=Label("@swig//:swig"),
- executable=True,
- cfg="host",),
- "_swiglib":
- attr.label(
- default=Label("@swig//:templates"),
- allow_files=True,),
+ attrs = {
+ "srcs": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ "swig_includes": attr.label_list(
+ cfg = "data",
+ allow_files = True,
+ ),
+ "deps": attr.label_list(
+ allow_files = True,
+ providers = ["cc"],
+ ),
+ "toolchain_deps": attr.label_list(
+ allow_files = True,
+ ),
+ "module_name": attr.string(mandatory = True),
+ "py_module_name": attr.string(mandatory = True),
+ "_swig": attr.label(
+ default = Label("@swig//:swig"),
+ executable = True,
+ cfg = "host",
+ ),
+ "_swiglib": attr.label(
+ default = Label("@swig//:templates"),
+ allow_files = True,
+ ),
},
- outputs={
+ outputs = {
"cc_out": "%{module_name}.cc",
"py_out": "%{py_module_name}.py",
},
- implementation=_py_wrap_cc_impl,)
-
+ implementation = _py_wrap_cc_impl,
+)
def _get_repository_roots(ctx, files):
"""Returns abnormal root directories under which files reside.
@@ -1085,7 +1051,6 @@ def _get_repository_roots(ctx, files):
result[root] -= 1
return [k for v, k in sorted([(v, k) for k, v in result.items()])]
-
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
outputs = depset()
@@ -1093,21 +1058,20 @@ def _transitive_hdrs_impl(ctx):
outputs += dep.cc.transitive_headers
return struct(files=outputs)
-
_transitive_hdrs = rule(
- attrs={
+ attrs = {
"deps": attr.label_list(
- allow_files=True,
- providers=["cc"],),
+ allow_files = True,
+ providers = ["cc"],
+ ),
},
- implementation=_transitive_hdrs_impl,)
-
+ implementation = _transitive_hdrs_impl,
+)
def transitive_hdrs(name, deps=[], **kwargs):
_transitive_hdrs(name=name + "_gather", deps=deps)
native.filegroup(name=name, srcs=[":" + name + "_gather"])
-
# Create a header only library that includes all the headers exported by
# the libraries in deps.
def cc_header_only_library(name, deps=[], includes=[], **kwargs):
@@ -1133,7 +1097,6 @@ def cc_header_only_library(name, deps=[], includes=[], **kwargs):
includes=includes,
**kwargs)
-
def tf_custom_op_library_additional_deps():
return [
"@protobuf_archive//:protobuf_headers",
@@ -1142,7 +1105,6 @@ def tf_custom_op_library_additional_deps():
clean_dep("//tensorflow/core:framework_headers_lib"),
]
-
# Traverse the dependency graph along the "deps" attribute of the
# target and return a struct with one field called 'tf_collected_deps'.
# tf_collected_deps will be the union of the deps of the current target
@@ -1156,16 +1118,15 @@ def _collect_deps_aspect_impl(target, ctx):
alldeps = alldeps | dep.tf_collected_deps
return struct(tf_collected_deps=alldeps)
-
collect_deps_aspect = aspect(
- implementation=_collect_deps_aspect_impl, attr_aspects=["deps"])
-
+ attr_aspects = ["deps"],
+ implementation = _collect_deps_aspect_impl,
+)
def _dep_label(dep):
label = dep.label
return label.package + ":" + label.name
-
# This rule checks that the transitive dependencies of targets listed
# in the 'deps' attribute don't depend on the targets listed in
# the 'disallowed_deps' attribute.
@@ -1182,18 +1143,20 @@ def _check_deps_impl(ctx):
disallowed_dep))
return struct()
-
check_deps = rule(
_check_deps_impl,
- attrs={
- "deps":
- attr.label_list(
- aspects=[collect_deps_aspect], mandatory=True,
- allow_files=True),
- "disallowed_deps":
- attr.label_list(mandatory=True, allow_files=True)
- },)
-
+ attrs = {
+ "deps": attr.label_list(
+ aspects = [collect_deps_aspect],
+ mandatory = True,
+ allow_files = True,
+ ),
+ "disallowed_deps": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ },
+)
# Helper to build a dynamic library (.so) from the sources containing
# implementations of custom ops and kernels.
@@ -1234,9 +1197,9 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
}),)
register_extension_info(
- extension_name="tf_custom_op_library",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_custom_op_library",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_custom_op_py_library(name,
srcs=[],
@@ -1255,18 +1218,16 @@ def tf_custom_op_py_library(name,
deps=deps,)
register_extension_info(
- extension_name="tf_custom_op_py_library",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "tf_custom_op_py_library",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_extension_linkopts():
return [] # No extension link opts
-
def tf_extension_copts():
return [] # No extension c opts
-
def tf_py_wrap_cc(name,
srcs,
swig_includes=[],
@@ -1334,7 +1295,6 @@ def tf_py_wrap_cc(name,
"//conditions:default": [":" + cc_library_name],
}))
-
def py_test(deps=[], **kwargs):
native.py_test(
deps=select({
@@ -1344,9 +1304,9 @@ def py_test(deps=[], **kwargs):
**kwargs)
register_extension_info(
- extension_name="py_test",
- label_regex_for_dep="{extension_name}")
-
+ extension_name = "py_test",
+ label_regex_for_dep = "{extension_name}",
+)
def tf_py_test(name,
srcs,
@@ -1382,9 +1342,9 @@ def tf_py_test(name,
srcs_version="PY2AND3")
register_extension_info(
- extension_name="tf_py_test",
- label_regex_map={"additional_deps": "deps:{extension_name}"})
-
+ extension_name = "tf_py_test",
+ label_regex_map = {"additional_deps": "deps:{extension_name}"},
+)
def cuda_py_test(name,
srcs,
@@ -1412,9 +1372,9 @@ def cuda_py_test(name,
xla_enabled=xla_enabled)
register_extension_info(
- extension_name="cuda_py_test",
- label_regex_map={"additional_deps": "additional_deps:{extension_name}"})
-
+ extension_name = "cuda_py_test",
+ label_regex_map = {"additional_deps": "additional_deps:{extension_name}"},
+)
def sycl_py_test(name,
srcs,
@@ -1442,9 +1402,9 @@ def sycl_py_test(name,
xla_enabled=xla_enabled)
register_extension_info(
- extension_name="sycl_py_test",
- label_regex_map={"additional_deps": "additional_deps:{extension_name}"})
-
+ extension_name = "sycl_py_test",
+ label_regex_map = {"additional_deps": "additional_deps:{extension_name}"},
+)
def py_tests(name,
srcs,
@@ -1470,7 +1430,6 @@ def py_tests(name,
additional_deps=additional_deps,
xla_enabled=xla_enabled)
-
def cuda_py_tests(name,
srcs,
size="medium",
@@ -1492,7 +1451,6 @@ def cuda_py_tests(name,
prefix=prefix,
xla_enabled=xla_enabled)
-
# Creates a genrule named <name> for running tools/proto_text's generator to
# make the proto_text functions, for the protos passed in <srcs>.
#
@@ -1515,12 +1473,10 @@ def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs):
],)
return struct(hdrs=out_hdrs, srcs=out_srcs)
-
def tf_genrule_cmd_append_to_srcs(to_append):
return ("cat $(SRCS) > $(@) && " + "echo >> $(@) && " + "echo " + to_append +
" >> $(@)")
-
def tf_version_info_genrule():
native.genrule(
name="version_info_gen",
@@ -1535,7 +1491,6 @@ def tf_version_info_genrule():
local=1,
tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],)
-
def tf_py_build_info_genrule():
native.genrule(
name="py_build_info_gen",
@@ -1545,7 +1500,6 @@ def tf_py_build_info_genrule():
local=1,
tools=[clean_dep("//tensorflow/tools/build_info:gen_build_info.py")],)
-
def cc_library_with_android_deps(deps,
android_deps=[],
common_deps=[],
@@ -1554,5 +1508,6 @@ def cc_library_with_android_deps(deps,
native.cc_library(deps=deps, **kwargs)
register_extension_info(
- extension_name="cc_library_with_android_deps",
- label_regex_for_dep="{extension_name}")
+ extension_name = "cc_library_with_android_deps",
+ label_regex_for_dep = "{extension_name}",
+)