diff options
Diffstat (limited to 'tensorflow/workspace.bzl')
-rw-r--r-- | tensorflow/workspace.bzl | 45 |
1 files changed, 21 insertions, 24 deletions
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index dfe332b091..afcae6eade 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,24 +1,21 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") - load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") -load( - "@io_bazel_rules_closure//closure/private:java_import_external.bzl", - "java_import_external", -) +load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", + "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") load("//third_party/py:python_configure.bzl", "python_configure") -load( - "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", - "arm_compiler_configure", -) +load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", + "arm_compiler_configure") + def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" return repository_ctx.os.name.lower().find("windows") != -1 + def _get_env_var(repository_ctx, name): """Find an environment variable.""" if name in repository_ctx.os.environ: @@ -26,6 +23,7 @@ def _get_env_var(repository_ctx, name): else: return None + # Parse the bazel version string from `native.bazel_version`. def _parse_bazel_version(bazel_version): # Remove commit from version. @@ -41,6 +39,7 @@ def _parse_bazel_version(bazel_version): version_tuple += (str(number),) return version_tuple + # Check that a specific bazel version is being used. def check_version(bazel_version): if "bazel_version" not in dir(native): @@ -57,9 +56,11 @@ def check_version(bazel_version): fail("\nCurrent Bazel version is {}, expected at least {}\n".format( native.bazel_version, bazel_version)) + def _repos_are_siblings(): return Label("@foo//bar").workspace_root.startswith("../") + # Temporary workaround to support including TensorFlow as a submodule until this # use-case is supported in the next Bazel release. def _temp_workaround_http_archive_impl(repo_ctx): @@ -72,7 +73,9 @@ def _temp_workaround_http_archive_impl(repo_ctx): if repo_ctx.attr.patch_file != None: _apply_patch(repo_ctx, repo_ctx.attr.patch_file) + temp_workaround_http_archive = repository_rule( + implementation = _temp_workaround_http_archive_impl, attrs = { "build_file": attr.label(), "repository": attr.string(), @@ -81,7 +84,6 @@ temp_workaround_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, - implementation = _temp_workaround_http_archive_impl, ) # Executes specified command with arguments and calls 'fail' if it exited with @@ -93,6 +95,7 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code, result.stdout, result.stderr)) + # Apply a patch_file to the repository root directory # Runs 'patch -p1' def _apply_patch(repo_ctx, patch_file): @@ -110,6 +113,7 @@ def _apply_patch(repo_ctx, patch_file): cmd = [bazel_sh, "-c", " ".join(cmd)] _execute_and_check_ret_code(repo_ctx, cmd) + # Download the repository and apply a patch to its root def _patched_http_archive_impl(repo_ctx): repo_ctx.download_and_extract( @@ -118,7 +122,9 @@ def _patched_http_archive_impl(repo_ctx): stripPrefix=repo_ctx.attr.strip_prefix) _apply_patch(repo_ctx, repo_ctx.attr.patch_file) + patched_http_archive = repository_rule( + implementation = _patched_http_archive_impl, attrs = { "patch_file": attr.label(), "build_file": attr.label(), @@ -127,9 +133,9 @@ patched_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, - implementation = _patched_http_archive_impl, ) + # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. @@ -442,11 +448,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", - # "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz", + # "https://github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz", ], - sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b", - strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323", + sha256 = "ffbbe828f3d0bef75462e34801de5cea31d10aa63eaa42a4ed74c46521bdfd58", + strip_prefix = "nsync-4fc8ff3e7626c5f24bc9674438d8257f0ffc226c", ) native.http_archive( @@ -815,12 +821,3 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", ], ) - - native.new_http_archive( - name = "tflite_mobilenet", - build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), - sha256 = "eb71679d23a0cbdb173b36ea39f3d3096de0a9b0410d148a8237f20cc1157a61", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_1.0_224_quantized_2017_11_01.zip" - ], - ) |