From 1357711e6faf688f863821c35be6c358891616ec Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Wed, 15 Aug 2018 17:33:29 -0700 Subject: Create new interface for specifying TensorFlow third party imports PiperOrigin-RevId: 208913262 --- third_party/repo.bzl | 229 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 151 insertions(+), 78 deletions(-) (limited to 'third_party/repo.bzl') diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 5cb42691c5..7d1aa5dce9 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -19,104 +19,111 @@ _SINGLE_URL_WHITELIST = depset([ ]) def _is_windows(ctx): - return ctx.os.name.lower().find("windows") != -1 + return ctx.os.name.lower().find("windows") != -1 def _wrap_bash_cmd(ctx, cmd): - if _is_windows(ctx): - bazel_sh = _get_env_var(ctx, "BAZEL_SH") - if not bazel_sh: - fail("BAZEL_SH environment variable is not set") - cmd = [bazel_sh, "-l", "-c", " ".join(cmd)] - return cmd + if _is_windows(ctx): + bazel_sh = _get_env_var(ctx, "BAZEL_SH") + if not bazel_sh: + fail("BAZEL_SH environment variable is not set") + cmd = [bazel_sh, "-l", "-c", " ".join(cmd)] + return cmd def _get_env_var(ctx, name): - if name in ctx.os.environ: - return ctx.os.environ[name] - else: - return None + if name in ctx.os.environ: + return ctx.os.environ[name] + else: + return None # Checks if we should use the system lib instead of the bundled one def _use_system_lib(ctx, name): - syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") - if syslibenv: - for n in syslibenv.strip().split(","): - if n.strip() == name: - return True - return False + syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") + if syslibenv: + for n in syslibenv.strip().split(","): + if n.strip() == name: + return True + return False # Executes specified command with arguments and calls 'fail' if it exited with # non-zero code def _execute_and_check_ret_code(repo_ctx, cmd_and_args): - result = repo_ctx.execute(cmd_and_args, timeout=10) - if result.return_code != 0: - fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" - + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code, - result.stdout, result.stderr)) + result = repo_ctx.execute(cmd_and_args, timeout = 10) + if result.return_code != 0: + fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" + + "Stderr: {3}").format( + " ".join(cmd_and_args), + result.return_code, + result.stdout, + result.stderr, + )) def _repos_are_siblings(): - return Label("@foo//bar").workspace_root.startswith("../") + return Label("@foo//bar").workspace_root.startswith("../") # Apply a patch_file to the repository root directory # Runs 'patch -p1' def _apply_patch(ctx, patch_file): - # Don't check patch on Windows, because patch is only available under bash. - if not _is_windows(ctx) and not ctx.which("patch"): - fail("patch command is not found, please install it") - cmd = _wrap_bash_cmd( - ctx, ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)]) - _execute_and_check_ret_code(ctx, cmd) + # Don't check patch on Windows, because patch is only available under bash. + if not _is_windows(ctx) and not ctx.which("patch"): + fail("patch command is not found, please install it") + cmd = _wrap_bash_cmd( + ctx, + ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)], + ) + _execute_and_check_ret_code(ctx, cmd) def _apply_delete(ctx, paths): - for path in paths: - if path.startswith("/"): - fail("refusing to rm -rf path starting with '/': " + path) - if ".." in path: - fail("refusing to rm -rf path containing '..': " + path) - cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths]) - _execute_and_check_ret_code(ctx, cmd) + for path in paths: + if path.startswith("/"): + fail("refusing to rm -rf path starting with '/': " + path) + if ".." in path: + fail("refusing to rm -rf path containing '..': " + path) + cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths]) + _execute_and_check_ret_code(ctx, cmd) def _tf_http_archive(ctx): - if ("mirror.bazel.build" not in ctx.attr.urls[0] and - (len(ctx.attr.urls) < 2 and - ctx.attr.name not in _SINGLE_URL_WHITELIST)): - fail("tf_http_archive(urls) must have redundant URLs. The " + - "mirror.bazel.build URL must be present and it must come first. " + - "Even if you don't have permission to mirror the file, please " + - "put the correctly formatted mirror URL there anyway, because " + - "someone will come along shortly thereafter and mirror the file.") - - use_syslib = _use_system_lib(ctx, ctx.attr.name) - if not use_syslib: - ctx.download_and_extract( - ctx.attr.urls, - "", - ctx.attr.sha256, - ctx.attr.type, - ctx.attr.strip_prefix) - if ctx.attr.delete: - _apply_delete(ctx, ctx.attr.delete) - if ctx.attr.patch_file != None: - _apply_patch(ctx, ctx.attr.patch_file) - - if use_syslib and ctx.attr.system_build_file != None: - # Use BUILD.bazel to avoid conflict with third party projects with - # BUILD or build (directory) underneath. - ctx.template("BUILD.bazel", ctx.attr.system_build_file, { - "%prefix%": ".." if _repos_are_siblings() else "external", - }, False) - - elif ctx.attr.build_file != None: - # Use BUILD.bazel to avoid conflict with third party projects with - # BUILD or build (directory) underneath. - ctx.template("BUILD.bazel", ctx.attr.build_file, { - "%prefix%": ".." if _repos_are_siblings() else "external", - }, False) + if ("mirror.bazel.build" not in ctx.attr.urls[0] and + (len(ctx.attr.urls) < 2 and + ctx.attr.name not in _SINGLE_URL_WHITELIST)): + fail("tf_http_archive(urls) must have redundant URLs. The " + + "mirror.bazel.build URL must be present and it must come first. " + + "Even if you don't have permission to mirror the file, please " + + "put the correctly formatted mirror URL there anyway, because " + + "someone will come along shortly thereafter and mirror the file.") + + use_syslib = _use_system_lib(ctx, ctx.attr.name) + if not use_syslib: + ctx.download_and_extract( + ctx.attr.urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix, + ) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + + if use_syslib and ctx.attr.system_build_file != None: + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.system_build_file, { + "%prefix%": ".." if _repos_are_siblings() else "external", + }, False) + + elif ctx.attr.build_file != None: + # Use BUILD.bazel to avoid conflict with third party projects with + # BUILD or build (directory) underneath. + ctx.template("BUILD.bazel", ctx.attr.build_file, { + "%prefix%": ".." if _repos_are_siblings() else "external", + }, False) tf_http_archive = repository_rule( - implementation=_tf_http_archive, - attrs={ - "sha256": attr.string(mandatory=True), - "urls": attr.string_list(mandatory=True, allow_empty=False), + implementation = _tf_http_archive, + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list(mandatory = True, allow_empty = False), "strip_prefix": attr.string(), "type": attr.string(), "delete": attr.string_list(), @@ -124,12 +131,78 @@ tf_http_archive = repository_rule( "build_file": attr.label(), "system_build_file": attr.label(), }, - environ=[ - "TF_SYSTEM_LIBS", - ]) + environ = [ + "TF_SYSTEM_LIBS", + ], +) """Downloads and creates Bazel repos for dependencies. This is a swappable replacement for both http_archive() and new_http_archive() that offers some additional features. It also helps ensure best practices are followed. """ + +def _third_party_http_archive(ctx): + if ("mirror.bazel.build" not in ctx.attr.urls[0] and + (len(ctx.attr.urls) < 2 and + ctx.attr.name not in _SINGLE_URL_WHITELIST)): + fail("tf_http_archive(urls) must have redundant URLs. The " + + "mirror.bazel.build URL must be present and it must come first. " + + "Even if you don't have permission to mirror the file, please " + + "put the correctly formatted mirror URL there anyway, because " + + "someone will come along shortly thereafter and mirror the file.") + + use_syslib = _use_system_lib(ctx, ctx.attr.name) + + # Use "BUILD.bazel" to avoid conflict with third party projects that contain a + # file or directory called "BUILD" + buildfile_path = ctx.path("BUILD.bazel") + + if use_syslib: + if ctx.attr.system_build_file == None: + fail("Bazel was configured with TF_SYSTEM_LIBS to use a system " + + "library for %s, but no system build file for %s was configured. " + + "Please add a system_build_file attribute to the repository rule" + + "for %s." % (ctx.attr.name, ctx.attr.name, ctx.attr.name)) + ctx.symlink(Label(ctx.attr.system_build_file), buildfile_path) + + else: + ctx.download_and_extract( + ctx.attr.urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix, + ) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + ctx.symlink(Label(ctx.attr.build_file), buildfile_path) + + for internal_src, external_dest in ctx.attr.link_files.items(): + ctx.symlink(Label(internal_src), ctx.path(external_dest)) + +# Downloads and creates Bazel repos for dependencies. +# +# This is an upgrade for tf_http_archive that works with go/tfbr-thirdparty. +# +# For link_files, specify each dict entry as: +# "//path/to/source:file": "localfile" +third_party_http_archive = repository_rule( + implementation = _third_party_http_archive, + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list(mandatory = True, allow_empty = False), + "strip_prefix": attr.string(), + "type": attr.string(), + "delete": attr.string_list(), + "build_file": attr.string(mandatory = True), + "system_build_file": attr.string(mandatory = False), + "patch_file": attr.label(), + "link_files": attr.string_dict(), + }, + environ = [ + "TF_SYSTEM_LIBS", + ], +) -- cgit v1.2.3