diff options
-rw-r--r-- | tensorflow/workspace.bzl | 37 | ||||
-rw-r--r-- | third_party/tflite_mobilenet.BUILD | 13 |
2 files changed, 33 insertions, 17 deletions
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index afcae6eade..3081a8d1dc 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,21 +1,24 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") + load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") -load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", - "java_import_external") +load( + "@io_bazel_rules_closure//closure/private:java_import_external.bzl", + "java_import_external", +) load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") load("//third_party/py:python_configure.bzl", "python_configure") -load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", - "arm_compiler_configure") - +load( + "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", + "arm_compiler_configure", +) def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" return repository_ctx.os.name.lower().find("windows") != -1 - def _get_env_var(repository_ctx, name): """Find an environment variable.""" if name in repository_ctx.os.environ: @@ -23,7 +26,6 @@ def _get_env_var(repository_ctx, name): else: return None - # Parse the bazel version string from `native.bazel_version`. def _parse_bazel_version(bazel_version): # Remove commit from version. @@ -39,7 +41,6 @@ def _parse_bazel_version(bazel_version): version_tuple += (str(number),) return version_tuple - # Check that a specific bazel version is being used. def check_version(bazel_version): if "bazel_version" not in dir(native): @@ -56,11 +57,9 @@ def check_version(bazel_version): fail("\nCurrent Bazel version is {}, expected at least {}\n".format( native.bazel_version, bazel_version)) - def _repos_are_siblings(): return Label("@foo//bar").workspace_root.startswith("../") - # Temporary workaround to support including TensorFlow as a submodule until this # use-case is supported in the next Bazel release. def _temp_workaround_http_archive_impl(repo_ctx): @@ -73,9 +72,7 @@ def _temp_workaround_http_archive_impl(repo_ctx): if repo_ctx.attr.patch_file != None: _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - temp_workaround_http_archive = repository_rule( - implementation = _temp_workaround_http_archive_impl, attrs = { "build_file": attr.label(), "repository": attr.string(), @@ -84,6 +81,7 @@ temp_workaround_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, + implementation = _temp_workaround_http_archive_impl, ) # Executes specified command with arguments and calls 'fail' if it exited with @@ -95,7 +93,6 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code, result.stdout, result.stderr)) - # Apply a patch_file to the repository root directory # Runs 'patch -p1' def _apply_patch(repo_ctx, patch_file): @@ -113,7 +110,6 @@ def _apply_patch(repo_ctx, patch_file): cmd = [bazel_sh, "-c", " ".join(cmd)] _execute_and_check_ret_code(repo_ctx, cmd) - # Download the repository and apply a patch to its root def _patched_http_archive_impl(repo_ctx): repo_ctx.download_and_extract( @@ -122,9 +118,7 @@ def _patched_http_archive_impl(repo_ctx): stripPrefix=repo_ctx.attr.strip_prefix) _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - patched_http_archive = repository_rule( - implementation = _patched_http_archive_impl, attrs = { "patch_file": attr.label(), "build_file": attr.label(), @@ -133,9 +127,9 @@ patched_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, + implementation = _patched_http_archive_impl, ) - # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. @@ -821,3 +815,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", ], ) + + native.new_http_archive( + name = "tflite_mobilenet", + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + sha256 = "eb71679d23a0cbdb173b36ea39f3d3096de0a9b0410d148a8237f20cc1157a61", + urls = [ + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_1.0_224_quantized_2017_11_01.zip" + ], + ) diff --git a/third_party/tflite_mobilenet.BUILD b/third_party/tflite_mobilenet.BUILD new file mode 100644 index 0000000000..75663eff48 --- /dev/null +++ b/third_party/tflite_mobilenet.BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "model_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) |