aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/workspace.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/workspace.bzl')
-rw-r--r--tensorflow/workspace.bzl56
1 files changed, 50 insertions, 6 deletions
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index a13142fe48..f8dfd21f84 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -50,17 +50,54 @@ def _temp_workaround_http_archive_impl(repo_ctx):
}, False)
repo_ctx.download_and_extract(repo_ctx.attr.urls, "", repo_ctx.attr.sha256,
"", repo_ctx.attr.strip_prefix)
+ if repo_ctx.attr.patch_file != None:
+ _apply_patch(repo_ctx, repo_ctx.attr.patch_file)
temp_workaround_http_archive = repository_rule(
implementation=_temp_workaround_http_archive_impl,
attrs = {
"build_file": attr.label(),
"repository": attr.string(),
+ "patch_file": attr.label(default = None),
"urls": attr.string_list(default = []),
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
})
+# Executes specified command with arguments and calls 'fail' if it exited with non-zero code
+def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
+ result = repo_ctx.execute(cmd_and_args)
+ if result.return_code != 0:
+ fail(("Non-zero return code({1}) when executing '{0}':\n" +
+ "Stdout: {2}\n" +
+ "Stderr: {3}").format(" ".join(cmd_and_args),
+ result.return_code, result.stdout, result.stderr))
+
+# Apply a patch_file to the repository root directory
+# Runs 'patch -p1'
+def _apply_patch(repo_ctx, patch_file):
+ _execute_and_check_ret_code(repo_ctx, ["patch", "-p1",
+ "-d", repo_ctx.path("."),
+ "-i", repo_ctx.path(patch_file)])
+
+# Download the repository and apply a patch to its root
+def _patched_http_archive_impl(repo_ctx):
+ repo_ctx.download_and_extract(repo_ctx.attr.urls,
+ sha256 = repo_ctx.attr.sha256,
+ stripPrefix = repo_ctx.attr.strip_prefix)
+ _apply_patch(repo_ctx, repo_ctx.attr.patch_file)
+
+patched_http_archive = repository_rule(
+ implementation = _patched_http_archive_impl,
+ attrs = {
+ "patch_file": attr.label(),
+ "build_file": attr.label(),
+ "repository": attr.string(),
+ "urls": attr.string_list(default = []),
+ "sha256": attr.string(default = ""),
+ "strip_prefix": attr.string(default = ""),
+ })
+
# If TensorFlow is linked as a submodule.
# path_prefix and tf_repo_name are no longer used.
def tf_workspace(path_prefix = "", tf_repo_name = ""):
@@ -78,11 +115,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "eigen_archive",
urls = [
- "http://bazel-mirror.storage.googleapis.com/bitbucket.org/eigen/eigen/get/9c6361787292.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/9c6361787292.tar.gz",
+ "http://bazel-mirror.storage.googleapis.com/bitbucket.org/eigen/eigen/get/deff8b280204.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/deff8b280204.tar.gz",
],
- sha256 = "e6ec2502a5d82dd5df0b9b16e7697f5fccb81c322d0be8e3492969eecb66badd",
- strip_prefix = "eigen-eigen-9c6361787292",
+ sha256 = "a39834683eb5bdb9a7434f0ab3621d2cbc3b07e8002db6de101e45ec536723eb",
+ strip_prefix = "eigen-eigen-deff8b280204",
build_file = str(Label("//third_party:eigen.BUILD")),
)
@@ -255,7 +292,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
actual = "@six_archive//:six",
)
- native.http_archive(
+ patched_http_archive(
name = "protobuf",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz",
@@ -263,6 +300,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0",
strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a",
+ # TODO: remove patching when tensorflow stops linking same protos into
+ # multiple shared libraries loaded in runtime by python.
+ # This patch fixes a runtime crash when tensorflow is compiled
+ # with clang -O2 on Linux (see https://github.com/tensorflow/tensorflow/issues/8394)
+ patch_file = str(Label("//third_party/protobuf:add_noinlines.patch")),
)
native.new_http_archive(
@@ -452,7 +494,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "6787f0eed88d52ee8e32956fa4947d92c139da469f1d8e311c307f27d641118e",
strip_prefix = "nccl-024d1e267845f2ed06f3e2e42476d50f04a00ee6",
- build_file = str(Label("//third_party:nccl.BUILD")),
+ build_file = str(Label("//third_party/nccl:nccl.BUILD")),
+ # TODO: Remove patching after the fix is merged into nccl(see https://github.com/NVIDIA/nccl/pull/78)
+ patch_file = str(Label("//third_party/nccl:fix_clang_compilation.patch")),
repository = tf_repo_name,
)