aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/workspace.bzl37
-rw-r--r--third_party/tflite_mobilenet.BUILD13
2 files changed, 33 insertions, 17 deletions
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index afcae6eade..3081a8d1dc 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,21 +1,24 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
-load("@io_bazel_rules_closure//closure/private:java_import_external.bzl",
- "java_import_external")
+load(
+ "@io_bazel_rules_closure//closure/private:java_import_external.bzl",
+ "java_import_external",
+)
load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
load("//third_party/py:python_configure.bzl", "python_configure")
-load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
- "arm_compiler_configure")
-
+load(
+ "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
+ "arm_compiler_configure",
+)
def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows."""
return repository_ctx.os.name.lower().find("windows") != -1
-
def _get_env_var(repository_ctx, name):
"""Find an environment variable."""
if name in repository_ctx.os.environ:
@@ -23,7 +26,6 @@ def _get_env_var(repository_ctx, name):
else:
return None
-
# Parse the bazel version string from `native.bazel_version`.
def _parse_bazel_version(bazel_version):
# Remove commit from version.
@@ -39,7 +41,6 @@ def _parse_bazel_version(bazel_version):
version_tuple += (str(number),)
return version_tuple
-
# Check that a specific bazel version is being used.
def check_version(bazel_version):
if "bazel_version" not in dir(native):
@@ -56,11 +57,9 @@ def check_version(bazel_version):
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
native.bazel_version, bazel_version))
-
def _repos_are_siblings():
return Label("@foo//bar").workspace_root.startswith("../")
-
# Temporary workaround to support including TensorFlow as a submodule until this
# use-case is supported in the next Bazel release.
def _temp_workaround_http_archive_impl(repo_ctx):
@@ -73,9 +72,7 @@ def _temp_workaround_http_archive_impl(repo_ctx):
if repo_ctx.attr.patch_file != None:
_apply_patch(repo_ctx, repo_ctx.attr.patch_file)
-
temp_workaround_http_archive = repository_rule(
- implementation = _temp_workaround_http_archive_impl,
attrs = {
"build_file": attr.label(),
"repository": attr.string(),
@@ -84,6 +81,7 @@ temp_workaround_http_archive = repository_rule(
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
},
+ implementation = _temp_workaround_http_archive_impl,
)
# Executes specified command with arguments and calls 'fail' if it exited with
@@ -95,7 +93,6 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
+ "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
result.stdout, result.stderr))
-
# Apply a patch_file to the repository root directory
# Runs 'patch -p1'
def _apply_patch(repo_ctx, patch_file):
@@ -113,7 +110,6 @@ def _apply_patch(repo_ctx, patch_file):
cmd = [bazel_sh, "-c", " ".join(cmd)]
_execute_and_check_ret_code(repo_ctx, cmd)
-
# Download the repository and apply a patch to its root
def _patched_http_archive_impl(repo_ctx):
repo_ctx.download_and_extract(
@@ -122,9 +118,7 @@ def _patched_http_archive_impl(repo_ctx):
stripPrefix=repo_ctx.attr.strip_prefix)
_apply_patch(repo_ctx, repo_ctx.attr.patch_file)
-
patched_http_archive = repository_rule(
- implementation = _patched_http_archive_impl,
attrs = {
"patch_file": attr.label(),
"build_file": attr.label(),
@@ -133,9 +127,9 @@ patched_http_archive = repository_rule(
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
},
+ implementation = _patched_http_archive_impl,
)
-
# If TensorFlow is linked as a submodule.
# path_prefix is no longer used.
# tf_repo_name is thought to be under consideration.
@@ -821,3 +815,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz",
],
)
+
+ native.new_http_archive(
+ name = "tflite_mobilenet",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ sha256 = "eb71679d23a0cbdb173b36ea39f3d3096de0a9b0410d148a8237f20cc1157a61",
+ urls = [
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_1.0_224_quantized_2017_11_01.zip"
+ ],
+ )
diff --git a/third_party/tflite_mobilenet.BUILD b/third_party/tflite_mobilenet.BUILD
new file mode 100644
index 0000000000..75663eff48
--- /dev/null
+++ b/third_party/tflite_mobilenet.BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "model_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+)