aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/repo.bzl
blob: 36f5aa5bdee43a511abf5634af85643ac7e11cfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright 2017 The TensorFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for defining TensorFlow Bazel dependencies."""

# Names of tf_http_archive repositories that may list a single URL that is not
# a mirror.bazel.build URL (i.e. they are exempt from the redundancy check in
# _tf_http_archive).
# This is a plain list rather than a depset: it is only used for membership
# tests, and `in`/iteration on depsets is deprecated in Bazel
# (--incompatible_depset_is_not_iterable).
_SINGLE_URL_WHITELIST = [
    "arm_compiler",
    "ortools_archive",
    "gemmlowp",
]

def _is_windows(ctx):
  """Returns True when the host OS name reported by ctx mentions Windows.

  Args:
    ctx: the repository_ctx.
  """
  return "windows" in ctx.os.name.lower()

def _wrap_bash_cmd(ctx, cmd):
  """Adapts a command so it can be executed on the current host.

  On non-Windows hosts the command is returned unchanged. On Windows the
  command is rewritten to run through `$BAZEL_SH -l -c "<command>"`.

  Args:
    ctx: the repository_ctx.
    cmd: list of command arguments; elements may be strings or `path`
      objects (e.g. the result of ctx.path()).

  Returns:
    The (possibly rewritten) list of arguments.
  """
  if _is_windows(ctx):
    bazel_sh = _get_env_var(ctx, "BAZEL_SH")
    if not bazel_sh:
      fail("BAZEL_SH environment variable is not set")
    # `bash -c` takes one command string. Each argument is converted to a
    # string (str.join rejects `path` objects) and double-quoted so that
    # paths containing spaces are not split into several words.
    cmd = [bazel_sh, "-l", "-c", " ".join(["\"%s\"" % arg for arg in cmd])]
  return cmd

def _get_env_var(ctx, name):
  """Looks up an environment variable of the host.

  Args:
    ctx: the repository_ctx.
    name: name of the environment variable.

  Returns:
    The variable's value, or None when it is not set.
  """
  return ctx.os.environ.get(name)

def _execute_and_check_ret_code(repo_ctx, cmd_and_args, timeout=60):
  """Executes a command and calls fail() if it exits with a non-zero code.

  Args:
    repo_ctx: the repository_ctx.
    cmd_and_args: list of the command and its arguments; elements may be
      strings or `path` objects.
    timeout: maximum run time in seconds. Commands such as `patch` or
      `rm -rf` on a large extracted archive can exceed a few seconds.
  """
  result = repo_ctx.execute(cmd_and_args, timeout=timeout)
  if result.return_code != 0:
    # Arguments may be `path` objects, which str.join rejects; convert them
    # so that building the error report cannot itself fail and hide the
    # real stdout/stderr.
    cmd_str = " ".join([str(arg) for arg in cmd_and_args])
    fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n"
          + "Stderr: {3}").format(cmd_str, result.return_code,
                                  result.stdout, result.stderr))

def _repos_are_siblings():
  """Reports whether external repositories are laid out as siblings.

  Probes the workspace_root of a label in an arbitrary external repository;
  a root starting with "../" means external repos sit next to the main one.
  """
  probe_root = Label("@foo//bar").workspace_root
  return probe_root.startswith("../")

def _apply_patch(ctx, patch_file):
  """Applies `patch_file` to the repository root directory via `patch -p1`.

  Args:
    ctx: the repository_ctx.
    patch_file: label of the patch file to apply.
  """
  # On Windows `patch` is only reachable from inside bash, so its presence
  # is not checked with ctx.which() there.
  if not _is_windows(ctx):
    if not ctx.which("patch"):
      fail("patch command is not found, please install it")
  repo_root = ctx.path(".")
  patch_path = ctx.path(patch_file)
  patch_cmd = ["patch", "-p1", "-d", repo_root, "-i", patch_path]
  _execute_and_check_ret_code(ctx, _wrap_bash_cmd(ctx, patch_cmd))

def _apply_delete(ctx, paths):
  """Recursively deletes the given paths, relative to the repository root.

  Every path is validated before anything is removed, so a single bad entry
  aborts the whole operation without deleting anything.

  Args:
    ctx: the repository_ctx.
    paths: list of paths relative to the repository root.
  """
  if not paths:
    return
  for path in paths:
    if path.startswith("/"):
      fail("refusing to rm -rf path starting with '/': " + path)
    if ".." in path:
      fail("refusing to rm -rf path containing '..': " + path)
    # Entries made only of "." and "/" (e.g. ".", "./") name the repository
    # root itself (cf. ctx.path(".") in _apply_patch); "" presumably does too.
    # Deleting it would wipe the freshly extracted archive.
    if not path.replace("/", "").replace(".", ""):
      fail("refusing to rm -rf the repository root: '" + path + "'")
  cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
  _execute_and_check_ret_code(ctx, cmd)

def _check_redundant_urls(ctx):
  """Fails unless ctx.attr.urls satisfies the URL redundancy policy.

  The policy is satisfied when the first URL is a mirror.bazel.build URL,
  when at least two URLs are given, or when the repository name appears in
  _SINGLE_URL_WHITELIST.
  """
  urls = ctx.attr.urls
  if "mirror.bazel.build" in urls[0]:
    return
  if len(urls) >= 2:
    return
  if ctx.attr.name in _SINGLE_URL_WHITELIST:
    return
  fail("tf_http_archive(urls) must have redundant URLs. The " +
       "mirror.bazel.build URL must be present and it must come first. " +
       "Even if you don't have permission to mirror the file, please " +
       "put the correctly formatted mirror URL there anyway, because " +
       "someone will come along shortly thereafter and mirror the file.")

def _tf_http_archive(ctx):
  """Implementation of the tf_http_archive repository rule.

  Steps, in order:
    1. enforce the URL redundancy policy;
    2. download and extract the archive (checked against sha256);
    3. delete the paths listed in `delete`, if any;
    4. apply `patch_file`, if given;
    5. instantiate `build_file` as BUILD, if given. Each `%prefix%` is
       replaced with ".." when external repos are siblings, otherwise with
       "external".

  Args:
    ctx: the repository_ctx.
  """
  _check_redundant_urls(ctx)
  attrs = ctx.attr
  ctx.download_and_extract(
      attrs.urls,
      "",
      attrs.sha256,
      attrs.type,
      attrs.strip_prefix)
  if attrs.delete:
    _apply_delete(ctx, attrs.delete)
  if attrs.patch_file != None:
    _apply_patch(ctx, attrs.patch_file)
  if attrs.build_file != None:
    prefix = ".." if _repos_are_siblings() else "external"
    ctx.template("BUILD", attrs.build_file, {"%prefix%": prefix}, False)

# Attribute schema of the tf_http_archive repository rule.
_TF_HTTP_ARCHIVE_ATTRS = {
    "build_file": attr.label(),
    "delete": attr.string_list(),
    "patch_file": attr.label(),
    "sha256": attr.string(mandatory=True),
    "strip_prefix": attr.string(),
    "type": attr.string(),
    "urls": attr.string_list(mandatory=True, allow_empty=False),
}

tf_http_archive = repository_rule(
    attrs=_TF_HTTP_ARCHIVE_ATTRS,
    implementation=_tf_http_archive,
)
"""Downloads and creates Bazel repos for dependencies.

This is a swappable replacement for both http_archive() and
new_http_archive() that offers some additional features. It also helps
ensure best practices are followed.

Attributes:
  urls: Required, non-empty list of download URLs. Unless the repository
    name is whitelisted, either the first URL must be a mirror.bazel.build
    URL or at least two URLs must be given.
  sha256: Required checksum passed to download_and_extract.
  strip_prefix: Directory prefix passed to download_and_extract.
  type: Archive type passed to download_and_extract.
  delete: Paths, relative to the repository root, removed with `rm -rf`
    after extraction. Absolute paths and paths containing '..' are refused.
  patch_file: Optional patch applied with `patch -p1` after deletion.
  build_file: Optional template written as the repository's BUILD file;
    `%prefix%` is replaced with ".." or "external" depending on how
    external repositories are laid out.
"""