aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/nccl/build_defs.bzl.tpl
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/nccl/build_defs.bzl.tpl')
-rw-r--r--third_party/nccl/build_defs.bzl.tpl210
1 files changed, 210 insertions, 0 deletions
diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl
new file mode 100644
index 0000000000..ede1d3dad5
--- /dev/null
+++ b/third_party/nccl/build_defs.bzl.tpl
@@ -0,0 +1,210 @@
+"""Repository rule for NCCL."""
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
+
+def _gen_nccl_h_impl(ctx):
+ """Creates nccl.h from a template."""
+ ctx.actions.expand_template(
+ output = ctx.outputs.output,
+ template = ctx.file.template,
+ substitutions = {
+ "${nccl:Major}": "2",
+ "${nccl:Minor}": "3",
+ "${nccl:Patch}": "5",
+ "${nccl:Suffix}": "",
+ "${nccl:Version}": "2305",
+ },
+ )
+gen_nccl_h = rule(
+ implementation = _gen_nccl_h_impl,
+ attrs = {
+ "template": attr.label(allow_single_file = True),
+ "output": attr.output(),
+ },
+)
+"""Creates the NCCL header file."""
+
+
+def _process_srcs_impl(ctx):
+ """Appends .cc to .cu files, patches include directives."""
+ files = []
+ for src in ctx.files.srcs:
+ if not src.is_source:
+ # Process only once, specifically "src/nccl.h".
+ files.append(src)
+ continue
+ name = src.basename
+ if src.extension == "cu":
+ name = ctx.attr.prefix + name + ".cc"
+ file = ctx.actions.declare_file(name, sibling = src)
+ ctx.actions.expand_template(
+ output = file,
+ template = src,
+ substitutions = {
+ "\"collectives.h": "\"collectives/collectives.h",
+ "\"../collectives.h": "\"collectives/collectives.h",
+ "#if __CUDACC_VER_MAJOR__":
+ "#if defined __CUDACC_VER_MAJOR__ && __CUDACC_VER_MAJOR__",
+ # Substitutions are applied in order.
+ "std::nullptr_t": "nullptr_t",
+ "nullptr_t": "std::nullptr_t",
+ },
+ )
+ files.append(file)
+ return [DefaultInfo(files = depset(files))]
+_process_srcs = rule(
+ implementation = _process_srcs_impl,
+ attrs = {
+ "srcs": attr.label_list(allow_files = True),
+ "prefix": attr.string(default = ""),
+ },
+)
+"""Processes the NCCL srcs so they can be compiled with bazel and clang."""
+
+
+def nccl_library(name, srcs=None, hdrs=None, prefix=None, **kwargs):
+ """Processes the srcs and hdrs and creates a cc_library."""
+
+ _process_srcs(
+ name = name + "_srcs",
+ srcs = srcs,
+ prefix = prefix,
+ )
+ _process_srcs(
+ name = name + "_hdrs",
+ srcs = hdrs,
+ )
+
+ native.cc_library(
+ name = name,
+ srcs = [name + "_srcs"] if srcs else [],
+ hdrs = [name + "_hdrs"] if hdrs else [],
+ **kwargs
+ )
+
+
+def rdc_copts():
+ """Returns copts for compiling relocatable device code."""
+
+ # The global functions can not have a lower register count than the
+ # device functions. This is enforced by setting a fixed register count.
+ # https://github.com/NVIDIA/nccl/blob/f93fe9bfd94884cec2ba711897222e0df5569a53/makefiles/common.mk#L48
+ maxrregcount = "-maxrregcount=96"
+
+ return cuda_default_copts() + select({
+ "@local_config_cuda//cuda:using_nvcc": [
+ "-nvcc_options",
+ "relocatable-device-code=true",
+ "-nvcc_options",
+ "ptxas-options=" + maxrregcount,
+ ],
+ "@local_config_cuda//cuda:using_clang": [
+ "-fcuda-rdc",
+ "-Xcuda-ptxas",
+ maxrregcount,
+ ],
+ "//conditions:default": [],
+ }) + ["-fvisibility=hidden"]
+
+
+def _filter_impl(ctx):
+ suffix = ctx.attr.suffix
+ files = [src for src in ctx.files.srcs if src.path.endswith(suffix)]
+ return [DefaultInfo(files = depset(files))]
+_filter = rule(
+ implementation = _filter_impl,
+ attrs = {
+ "srcs": attr.label_list(allow_files = True),
+ "suffix": attr.string(),
+ },
+)
+"""Filters the srcs to the ones ending with suffix."""
+
+
+def _gen_link_src_impl(ctx):
+ ctx.actions.expand_template(
+ output = ctx.outputs.output,
+ template = ctx.file.template,
+ substitutions = {
+ "REGISTERLINKBINARYFILE": '"%s"' % ctx.file.register_hdr.short_path,
+ "FATBINFILE": '"%s"' % ctx.file.fatbin_hdr.short_path,
+ },
+ )
+_gen_link_src = rule(
+ implementation = _gen_link_src_impl,
+ attrs = {
+ "register_hdr": attr.label(allow_single_file = True),
+ "fatbin_hdr": attr.label(allow_single_file = True),
+ "template": attr.label(allow_single_file = True),
+ "output": attr.output(),
+ },
+)
+"""Patches the include directives for the link.stub file."""
+
+
+def device_link(name, srcs):
+ """Links seperately compiled relocatable device code into a cc_library."""
+
+ # From .a and .pic.a archives, just use the latter.
+ _filter(
+ name = name + "_pic_a",
+ srcs = srcs,
+ suffix = ".pic.a",
+ )
+
+ # Device-link to cubins for each architecture.
+ images = []
+ cubins = []
+ for arch in %{gpu_architectures}:
+ cubin = "%s_%s.cubin" % (name, arch)
+ register_hdr = "%s_%s.h" % (name, arch)
+ nvlink = "@local_config_nccl//:nvlink"
+ cmd = ("$(location %s) --cpu-arch=X86_64 " % nvlink +
+ "--arch=%s $(SRCS) " % arch +
+ "--register-link-binaries=$(location %s) " % register_hdr +
+ "--output-file=$(location %s)" % cubin)
+ native.genrule(
+ name = "%s_%s" % (name, arch),
+ outs = [register_hdr, cubin],
+ srcs = [name + "_pic_a"],
+ cmd = cmd,
+ tools = [nvlink],
+ )
+ images.append("--image=profile=%s,file=$(location %s)" % (arch, cubin))
+ cubins.append(cubin)
+
+ # Generate fatbin header from all cubins.
+ fatbin_hdr = name + ".fatbin.h"
+ fatbinary = "@local_config_nccl//:cuda/bin/fatbinary"
+ cmd = ("PATH=$$CUDA_TOOLKIT_PATH/bin:$$PATH " + # for bin2c
+ "$(location %s) -64 --cmdline=--compile-only --link " % fatbinary +
+ "--compress-all %s --create=%%{name}.fatbin " % " ".join(images) +
+ "--embedded-fatbin=$@")
+ native.genrule(
+ name = name + "_fatbin_h",
+ outs = [fatbin_hdr],
+ srcs = cubins,
+ cmd = cmd,
+ tools = [fatbinary],
+ )
+
+ # Generate the source file #including the headers generated above.
+ _gen_link_src(
+ name = name + "_cc",
+ # Include just the last one, they are equivalent.
+ register_hdr = register_hdr,
+ fatbin_hdr = fatbin_hdr,
+ template = "@local_config_nccl//:cuda/bin/crt/link.stub",
+ output = name + ".cc",
+ )
+
+ # Compile the source file into the cc_library.
+ native.cc_library(
+ name = name,
+ srcs = [name + "_cc"],
+ textual_hdrs = [register_hdr, fatbin_hdr],
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudart_static",
+ ],
+ )