diff options
Diffstat (limited to 'tensorflow/workspace.bzl')
-rw-r--r-- | tensorflow/workspace.bzl | 56 |
1 files changed, 50 insertions, 6 deletions
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index a13142fe48..f8dfd21f84 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -50,17 +50,54 @@ def _temp_workaround_http_archive_impl(repo_ctx): }, False) repo_ctx.download_and_extract(repo_ctx.attr.urls, "", repo_ctx.attr.sha256, "", repo_ctx.attr.strip_prefix) + if repo_ctx.attr.patch_file != None: + _apply_patch(repo_ctx, repo_ctx.attr.patch_file) temp_workaround_http_archive = repository_rule( implementation=_temp_workaround_http_archive_impl, attrs = { "build_file": attr.label(), "repository": attr.string(), + "patch_file": attr.label(default = None), "urls": attr.string_list(default = []), "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }) +# Executes specified command with arguments and calls 'fail' if it exited with non-zero code +def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + result = repo_ctx.execute(cmd_and_args) + if result.return_code != 0: + fail(("Non-zero return code({1}) when executing '{0}':\n" + + "Stdout: {2}\n" + + "Stderr: {3}").format(" ".join(cmd_and_args), + result.return_code, result.stdout, result.stderr)) + +# Apply a patch_file to the repository root directory +# Runs 'patch -p1' +def _apply_patch(repo_ctx, patch_file): + _execute_and_check_ret_code(repo_ctx, ["patch", "-p1", + "-d", repo_ctx.path("."), + "-i", repo_ctx.path(patch_file)]) + +# Download the repository and apply a patch to its root +def _patched_http_archive_impl(repo_ctx): + repo_ctx.download_and_extract(repo_ctx.attr.urls, + sha256 = repo_ctx.attr.sha256, + stripPrefix = repo_ctx.attr.strip_prefix) + _apply_patch(repo_ctx, repo_ctx.attr.patch_file) + +patched_http_archive = repository_rule( + implementation = _patched_http_archive_impl, + attrs = { + "patch_file": attr.label(), + "build_file": attr.label(), + "repository": attr.string(), + "urls": attr.string_list(default = []), + "sha256": attr.string(default = ""), + "strip_prefix": attr.string(default = ""), + }) + # If TensorFlow is linked as a submodule. # path_prefix and tf_repo_name are no longer used. def tf_workspace(path_prefix = "", tf_repo_name = ""): @@ -78,11 +115,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "eigen_archive", urls = [ - "http://bazel-mirror.storage.googleapis.com/bitbucket.org/eigen/eigen/get/9c6361787292.tar.gz", - "https://bitbucket.org/eigen/eigen/get/9c6361787292.tar.gz", + "http://bazel-mirror.storage.googleapis.com/bitbucket.org/eigen/eigen/get/deff8b280204.tar.gz", + "https://bitbucket.org/eigen/eigen/get/deff8b280204.tar.gz", ], - sha256 = "e6ec2502a5d82dd5df0b9b16e7697f5fccb81c322d0be8e3492969eecb66badd", - strip_prefix = "eigen-eigen-9c6361787292", + sha256 = "a39834683eb5bdb9a7434f0ab3621d2cbc3b07e8002db6de101e45ec536723eb", + strip_prefix = "eigen-eigen-deff8b280204", build_file = str(Label("//third_party:eigen.BUILD")), ) @@ -255,7 +292,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): actual = "@six_archive//:six", ) - native.http_archive( + patched_http_archive( name = "protobuf", urls = [ "http://bazel-mirror.storage.googleapis.com/github.com/google/protobuf/archive/2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a.tar.gz", @@ -263,6 +300,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): ], sha256 = "e5d3d4e227a0f7afb8745df049bbd4d55474b158ca5aaa2a0e31099af24be1d0", strip_prefix = "protobuf-2b7430d96aeff2bb624c8d52182ff5e4b9f7f18a", + # TODO: remove patching when tensorflow stops linking same protos into + # multiple shared libraries loaded in runtime by python. + # This patch fixes a runtime crash when tensorflow is compiled + # with clang -O2 on Linux (see https://github.com/tensorflow/tensorflow/issues/8394) + patch_file = str(Label("//third_party/protobuf:add_noinlines.patch")), ) native.new_http_archive( @@ -452,7 +494,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): ], sha256 = "6787f0eed88d52ee8e32956fa4947d92c139da469f1d8e311c307f27d641118e", strip_prefix = "nccl-024d1e267845f2ed06f3e2e42476d50f04a00ee6", - build_file = str(Label("//third_party:nccl.BUILD")), + build_file = str(Label("//third_party/nccl:nccl.BUILD")), + # TODO: Remove patching after the fix is merged into nccl(see https://github.com/NVIDIA/nccl/pull/78) + patch_file = str(Label("//third_party/nccl:fix_clang_compilation.patch")), repository = tf_repo_name, ) |