aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/nccl/nccl_configure.bzl
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-05 03:09:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 03:11:33 -0700
commit1fda7645d132b71b9084b01945795e97e582adcd (patch)
tree7ae6dfa150c4b58dc339db0ae1f0b177c482433e /third_party/nccl/nccl_configure.bzl
parentdfa9921e6343727b05f42f8d4a918b19528ff994 (diff)
Add support for NCCL2. The configure script asks for what version of NCCL to use. The default is still NCCL 1 from GitHub. If the user chooses NCCL 2, it asks for the install directory.
The nccl_configure.bzl generates two different BUILD files based on the chose NCCL version. For NCCL 1, it aliases to the existing 'nccl_archive' http_repo on GitHub. For NCCL 2, it creates a target containing the NCCL 2 library and headers from the chosen install directory. PiperOrigin-RevId: 191718007
Diffstat (limited to 'third_party/nccl/nccl_configure.bzl')
-rw-r--r--third_party/nccl/nccl_configure.bzl172
1 files changed, 172 insertions, 0 deletions
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
new file mode 100644
index 0000000000..9dfcb18369
--- /dev/null
+++ b/third_party/nccl/nccl_configure.bzl
@@ -0,0 +1,172 @@
+# -*- Python -*-
+"""Repository rule for NCCL configuration.
+
+`nccl_configure` depends on the following environment variables:
+
+ * `TF_NCCL_VERSION`: The NCCL version.
+ * `NCCL_INSTALL_PATH`: The installation path of the NCCL library.
+"""
+
+load(
+ "//third_party/gpus:cuda_configure.bzl",
+ "auto_configure_fail",
+ "find_cuda_define",
+ "matches_version",
+)
+
+_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_TF_NCCL_VERSION = "TF_NCCL_VERSION"
+
+_DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR"
+_DEFINE_NCCL_MINOR = "#define NCCL_MINOR"
+_DEFINE_NCCL_PATCH = "#define NCCL_PATCH"
+
+_NCCL_DUMMY_BUILD_CONTENT = """
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ visibility = ["//visibility:public"],
+)
+"""
+
+_NCCL_ARCHIVE_BUILD_CONTENT = """
+filegroup(
+ name = "LICENSE",
+ data = ["@nccl_archive//:LICENSE.txt"],
+ visibility = ["//visibility:public"],
+)
+
+alias(
+ name = "nccl",
+ actual = "@nccl_archive//:nccl",
+ visibility = ["//visibility:public"],
+)
+"""
+
+_NCCL_LOCAL_BUILD_TEMPLATE = """
+filegroup(
+ name = "LICENSE",
+ data = ["nccl/NCCL-SLA.txt"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ srcs = ["nccl/lib/libnccl.so.%s"],
+ hdrs = ["nccl/include/nccl.h"],
+ include_prefix = "third_party/nccl",
+ strip_include_prefix = "nccl/include",
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ visibility = ["//visibility:public"],
+)
+"""
+
+
+def _find_nccl_header(repository_ctx, nccl_install_path):
+ """Finds the NCCL header on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+
+ Returns:
+ The path to the NCCL header.
+ """
+ header_path = repository_ctx.path("%s/include/nccl.h" % nccl_install_path)
+ if not header_path.exists:
+ auto_configure_fail("Cannot find %s" % str(header_path))
+ return header_path
+
+
+def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
+ """Checks whether the header file matches the specified version of NCCL.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+ nccl_version: The expected NCCL version.
+
+ Returns:
+ A string containing the library version of NCCL.
+ """
+ header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+ header_dir = str(header_path.realpath.dirname)
+ major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_MAJOR)
+ minor_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_MINOR)
+ patch_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_PATCH)
+ header_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+ if not matches_version(nccl_version, header_version):
+ auto_configure_fail(
+ ("NCCL library version detected from %s/nccl.h (%s) does not match " +
+ "TF_NCCL_VERSION (%s). To fix this rerun configure again.") %
+ (header_dir, header_version, nccl_version))
+
+
+def _find_nccl_lib(repository_ctx, nccl_install_path, nccl_version):
+ """Finds the given NCCL library on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library installation directory.
+ nccl_version: The version of NCCL library files as returned
+ by _nccl_version.
+
+ Returns:
+ The path to the NCCL library.
+ """
+ lib_path = repository_ctx.path("%s/lib/libnccl.so.%s" % (nccl_install_path,
+ nccl_version))
+ if not lib_path.exists:
+ auto_configure_fail("Cannot find NCCL library %s" % str(lib_path))
+ return lib_path
+
+
+def _nccl_configure_impl(repository_ctx):
+ """Implementation of the nccl_configure repository rule."""
+ if _TF_NCCL_VERSION not in repository_ctx.os.environ:
+ # Add a dummy build file to make bazel query happy.
+ repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
+ return
+
+ nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
+ if matches_version("1", nccl_version):
+ # Alias to GitHub target from @nccl_archive.
+ if not matches_version(nccl_version, "1.3"):
+ auto_configure_fail(
+ "NCCL from GitHub must use version 1.3 (got %s)" % nccl_version)
+ repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
+ else:
+ # Create target for locally installed NCCL.
+ nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
+ _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
+ repository_ctx.symlink(nccl_install_path, "nccl")
+ repository_ctx.file("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE % nccl_version)
+
+
+nccl_configure = repository_rule(
+ implementation=_nccl_configure_impl,
+ environ=[
+ _NCCL_INSTALL_PATH,
+ _TF_NCCL_VERSION,
+ ],
+)
+"""Detects and configures the NCCL configuration.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+nccl_configure(name = "local_config_nccl")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""