aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/download-models.gradle
blob: d3b67eab52bfbcf006755bb36396a0d71fb66f77 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/*
 * download-models.gradle
 *     Downloads model archives from ${MODEL_URL} and unzips them into the
 *     application's asset folder.
 * Input:
 *     project.ext.TMP_DIR: absolute path to hold downloaded zip files
 *     project.ext.ASSET_DIR: absolute path to save unzipped model files
 * Output:
 *     The contents of every archive listed in `models` (except LICENSE
 *     files) are unzipped into project.ext.ASSET_DIR.
 */
// Hard-coded model archives to fetch, given as paths relative to MODEL_URL.
// extractModels (below) expects each archive to be in TMP_DIR under the last
// path segment of its entry (e.g. 'ssd_mobilenet_v1_android_export.zip').
// Keep this list in sync with the BUILD file named by the LINT markers.
// LINT.IfChange
def models = ['inception_v1.zip',
              'object_detection/ssd_mobilenet_v1_android_export.zip',
              'stylize_v1.zip',
              'speech_commands_conv_actions.zip']
// LINT.ThenChange(//tensorflow/examples/android/BUILD)

// Root URL for model archives; "<MODEL_URL>/<entry>" is fetched for each
// entry in `models`.
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'

// Classpath for this script: provides the de.undercouch Download task type
// used by downloadFile.
buildscript {
    repositories {
        // JCenter has been sunset and is deprecated in Gradle, so resolve from
        // Maven Central first. Keep jcenter() only as a fallback for older
        // setups.
        mavenCentral()
        jcenter()
    }
    dependencies {
        classpath 'de.undercouch:gradle-download-task:3.2.0'
    }
}

import de.undercouch.gradle.tasks.download.Download
// Fetches every archive listed in `models` from MODEL_URL into TMP_DIR.
// overwrite true replaces any copy already there. extractModels only depends
// on this task when an archive is missing.
task downloadFile(type: Download) {
    def archiveUrls = models.collect { archive -> "${MODEL_URL}/" + archive }
    src archiveUrls
    dest new File(project.ext.TMP_DIR)
    overwrite true
}

// Unzips every downloaded model archive from TMP_DIR into ASSET_DIR with mode
// 0644, skipping LICENSE files. Whether downloadFile is needed is decided at
// configuration time by checking which archives already exist in TMP_DIR.
task extractModels(type: Copy) {
    // Archives are looked up in TMP_DIR by the last path segment of each entry.
    def archiveNames = models.collect { entry -> entry.split("/")[-1] }

    archiveNames.each { archive ->
        from zipTree(project.ext.TMP_DIR + '/' + archive)
    }

    into file(project.ext.ASSET_DIR)
    fileMode  0644
    exclude '**/LICENSE'

    // Pull in the download step only if at least one archive is missing.
    def needDownload = archiveNames.any { archive ->
        !(new File(project.ext.TMP_DIR + '/' + archive)).exists()
    }
    if (needDownload) {
        dependsOn downloadFile
    }
}

// Wires extractModels into the build as the Android plugin creates its tasks.
// assembleDebug/assembleRelease keep their original dependency on it. A
// dependsOn on assemble* only guarantees extractModels finishes before
// assemble* itself, not before the assets are merged into the APK. So every
// asset-merge task is made to depend on it as well.
// NOTE(review): assumes the Android Gradle plugin names its asset-merge tasks
// merge<Variant>Assets (e.g. mergeDebugAssets); confirm for the AGP in use.
tasks.whenTaskAdded { task ->
    def isAssemble = task.name in ['assembleDebug', 'assembleRelease']
    def isAssetMerge = task.name ==~ /merge\w*Assets/
    if (isAssemble || isAssetMerge) {
        task.dependsOn 'extractModels'
    }
}