aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-17 11:39:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-17 12:57:44 -0700
commit360f449d95cf487fd35dbcbc548a6b65fa7ae64f (patch)
tree5ec1f6a645d8128de9a3f8c13ec15e2aeac62b87
parentd687f9bb40d978912ef8ddd531fa39e032de4c39 (diff)
Android: Added download models into build.gradle for android example
Change: 150471440
-rw-r--r--tensorflow/examples/android/BUILD2
-rw-r--r--tensorflow/examples/android/README.md10
-rw-r--r--tensorflow/examples/android/build.gradle38
-rw-r--r--tensorflow/examples/android/download-models.gradle64
4 files changed, 86 insertions, 28 deletions
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 83b9dc49a3..c3206d81dc 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -86,6 +86,7 @@ android_binary(
],
)
+# LINT.IfChange
filegroup(
name = "external_assets",
srcs = [
@@ -94,6 +95,7 @@ filegroup(
"@stylize//:model_files",
],
)
+# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)
filegroup(
name = "all_files",
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index 3182b8f211..81a2a66617 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -112,10 +112,10 @@ The NDK API level may remain at 14.
The TensorFlow `GraphDef`s that contain the model definitions and weights
are not packaged in the repo because of their size. They are downloaded
automatically and packaged with the APK by Bazel via a new_http_archive defined
-in `WORKSPACE` during the build process.
+in `WORKSPACE` during the build process, and by Gradle via download-models.gradle.
-**Optional**: If you wish to place the models in your assets manually (E.g. for
-non-Bazel builds), remove all of the `model_files` entries from the `assets`
+**Optional**: If you wish to place the models in your assets manually,
+remove all of the `model_files` entries from the `assets`
list in `tensorflow_demo` found in the `[BUILD](BUILD)` file. Then download
and extract the archives yourself to the `assets` directory in the source tree:
@@ -131,6 +131,10 @@ done
This will extract the models and their associated metadata files to the local
assets/ directory.
+If you are using Gradle, make sure to remove download-models.gradle reference
+from build.gradle after your manually download models; otherwise gradle
+might download models again and overwrite your models.
+
##### Build
After editing your WORKSPACE file to update the SDK/NDK configuration,
diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/examples/android/build.gradle
index ebfaf6539b..414a97713e 100644
--- a/tensorflow/examples/android/build.gradle
+++ b/tensorflow/examples/android/build.gradle
@@ -41,9 +41,6 @@ if (buildWithMake) {
// automatically.
def makeNdkRoot = System.getenv('NDK_ROOT')
-// Location of model files required as assets
-def externalModelData = '../../../bazel-tensorflow/external'
-
// If building with Bazel, this is the location of the bazel binary.
// NOTE: Bazel does not yet support building for Android on Windows,
// so in this case the Makefile build must be used as described above.
@@ -52,6 +49,10 @@ def bazelLocation = '/usr/local/bin/bazel'
project.buildDir = 'gradleBuild'
getProject().setBuildDir('gradleBuild')
+// import DownloadModels task
+project.ext.ASSET_DIR = projectDir.toString() + '/assets'
+project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
+
buildscript {
repositories {
jcenter()
@@ -95,7 +96,7 @@ android {
aidl.srcDirs = ['src']
renderscript.srcDirs = ['src']
res.srcDirs = ['res']
- assets.srcDirs = ['assets']
+ assets.srcDirs = [project.ext.ASSET_DIR]
jniLibs.srcDirs = ['libs']
}
@@ -127,10 +128,6 @@ task buildNativeMake(type: Exec) {
//, '-T' // Uncomment to skip protobuf and speed up subsequent builds.
}
-task buildExternalAssets(type: Exec) {
- commandLine bazelLocation, 'build', '//tensorflow/examples/android:external_assets'
- outputs.files(externalModelData)
-}
task copyNativeLibs(type: Copy) {
from demoLibPath
@@ -141,22 +138,13 @@ task copyNativeLibs(type: Copy) {
fileMode 0644
}
-task copyExternalAssets(type: Copy) {
- from file(externalModelData).listFiles()
- include '*.pb'
- include '*.txt'
- include 'thumbnails/*.jpg'
- into 'assets'
- fileMode 0644
- dependsOn buildExternalAssets
-}
-
-def copyTasks = [copyNativeLibs]
-if (!buildWithMake) {
- // copyExternalAssets uses bazel, so only run it when requested.
- copyTasks.add(copyExternalAssets)
-}
-assemble.dependsOn copyTasks
+assemble.dependsOn copyNativeLibs
afterEvaluate {
- assembleDebug.dependsOn copyTasks
+ assembleDebug.dependsOn copyNativeLibs
+ assembleRelease.dependsOn copyNativeLibs
}
+
+// Download default models; if you wish to use your own models then
+// place them in the "assets" directory and comment out this line.
+apply from: "download-models.gradle"
+
diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle
new file mode 100644
index 0000000000..d60df15f08
--- /dev/null
+++ b/tensorflow/examples/android/download-models.gradle
@@ -0,0 +1,64 @@
+/*
+ * download-models.gradle
+ * Downloads model files from ${MODEL_URL} into application's asset folder
+ * Input:
+ * project.ext.TMP_DIR: absolute path to hold downloaded zip files
+ * project.ext.ASSET_DIR: absolute path to save unzipped model files
+ * Output:
+ * 3 model files will be downloaded into given folder of ext.ASSET_DIR
+ */
+// hard coded model files
+// LINT.IfChange
+def models = ['inception5h.zip',
+ 'mobile_multibox_v1a.zip',
+ 'stylize_v1.zip']
+// LINT.ThenChange(//tensorflow/examples/android/BUILD)
+
+// Root URL for model archives
+def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'
+
+buildscript {
+ repositories {
+ jcenter()
+ }
+ dependencies {
+ classpath 'de.undercouch:gradle-download-task:3.2.0'
+ }
+}
+
+import de.undercouch.gradle.tasks.download.Download
+task downloadFile(type: Download){
+ for (f in models) {
+ src "${MODEL_URL}/" + f
+ }
+ dest new File(project.ext.TMP_DIR)
+ overwrite true
+}
+
+task extractModels(type: Copy) {
+ for (f in models) {
+ from zipTree(project.ext.TMP_DIR + '/' + f)
+ }
+
+ into file(project.ext.ASSET_DIR)
+ fileMode 0644
+ exclude '**/LICENSE'
+
+ dependsOn downloadFile
+}
+
+afterEvaluate {
+ // if models are not available, download & unzip them
+ def needDownload = false
+ for (f in models) {
+ if (!(new File(project.ext.TMP_DIR + '/' + f)).exists()) {
+ needDownload = true
+ }
+ }
+
+ if (needDownload) {
+ assembleDebug.dependsOn extractModels
+ assembleRelease.dependsOn extractModels
+ }
+}
+