aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/build_def.bzl
blob: 5543acc1f5dabaa8a54ec4d1f2027bc66a00f6db (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""Generate Flatbuffer binary from json."""
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

def tflite_copts():
  """Returns the compile time flags used for TFLite targets.

  The result is a plain flag list concatenated with two `select()`
  expressions: one keyed on the target platform and one keyed on
  //tensorflow:with_default_optimizations.

  Returns:
    a list of copts combined with select objects
  """
  # Passed for every configuration.
  common_flags = ["-DFARMHASH_NO_CXX_STRING"]

  # Platform specific flags; platforms not listed here add nothing.
  platform_flags = select({
      str(Label("//tensorflow:android_arm64")): [
          "-std=c++11",
          "-O3",
      ],
      str(Label("//tensorflow:android_arm")): [
          "-mfpu=neon",
          "-mfloat-abi=softfp",
          "-std=c++11",
          "-O3",
      ],
      str(Label("//tensorflow:android_x86")): [
          "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
      ],
      str(Label("//tensorflow:ios_x86_64")): [
          "-msse4.1",
      ],
      "//conditions:default": [],
  })

  # The gemmlowp fallback define is added whenever
  # //tensorflow:with_default_optimizations does not match.
  optimization_flags = select({
      str(Label("//tensorflow:with_default_optimizations")): [],
      "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
  })

  return common_flags + platform_flags + optimization_flags

# Label of the linker version script; used as the default `linkscript`
# argument of tflite_jni_binary().
LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"

def tflite_linkopts_unstripped():
  """Returns size-reducing linker flags for TFLite without stripping symbols.

  Because the symbol table is kept, these flags are useful when trying to
  investigate the relative size of the symbols in TFLite.

  Returns:
     a select object with proper linkopts
  """
  android_flags = [
      "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
      "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
      "-Wl,--gc-sections",  # Eliminate unused code and data.
      "-Wl,--as-needed",  # Don't link unused libs.
  ]
  default_flags = [
      "-Wl,--icf=all",  # Identical code folding.
  ]

  # The mips configurations get no extra linker flags at all.
  return select({
      "//tensorflow:android": android_flags,
      "//tensorflow/contrib/lite:mips": [],
      "//tensorflow/contrib/lite:mips64": [],
      "//conditions:default": default_flags,
  })

def tflite_jni_linkopts_unstripped():
  """Returns size-reducing linker flags for TFLite JNI binaries, unstripped.

  Because the symbol table is kept, these flags are useful when trying to
  investigate the relative size of the symbols in TFLite.

  Returns:
     a select object with proper linkopts
  """
  android_flags = [
      "-Wl,--gc-sections",  # Eliminate unused code and data.
      "-Wl,--as-needed",  # Don't link unused libs.
  ]
  default_flags = [
      "-Wl,--icf=all",  # Identical code folding.
  ]

  # The mips configurations get no extra linker flags at all.
  return select({
      "//tensorflow:android": android_flags,
      "//tensorflow/contrib/lite:mips": [],
      "//tensorflow/contrib/lite:mips64": [],
      "//conditions:default": default_flags,
  })

def tflite_linkopts():
  """Returns linker flags to reduce size of TFLite binary.

  These are the flags from tflite_linkopts_unstripped() with symbol-table
  stripping added for //tensorflow:android.

  Returns:
     a select expression with proper linkopts
  """
  strip_flags = select({
      "//tensorflow:android": [
          "-s",  # Omit symbol table.
      ],
      "//conditions:default": [],
  })
  return tflite_linkopts_unstripped() + strip_flags

def tflite_jni_linkopts():
  """Returns linker flags to reduce size of TFLite binary with JNI.

  These are the flags from tflite_jni_linkopts_unstripped() with two extra
  flags added for //tensorflow:android.

  Returns:
     a select expression with proper linkopts
  """
  android_only_flags = select({
      "//tensorflow:android": [
          "-s",  # Omit symbol table.
          "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
      ],
      "//conditions:default": [],
  })
  return tflite_jni_linkopts_unstripped() + android_only_flags

def tflite_jni_binary(name,
                      copts=tflite_copts(),
                      linkopts=tflite_jni_linkopts(),
                      linkscript=LINKER_SCRIPT,
                      linkshared=1,
                      linkstatic=1,
                      deps=[]):
  """Builds a jni binary for TFLite.

  Args:
    name: Name of the generated cc_binary target.
    copts: Compiler flags; defaults to tflite_copts().
    linkopts: Linker flags; defaults to tflite_jni_linkopts(). The
      version-script flags are appended to whatever is passed in.
    linkscript: Label of the linker version script. It is referenced through
      a $(location ...) expansion and also appended to deps.
    linkshared: Forwarded unchanged to cc_binary.
    linkstatic: Forwarded unchanged to cc_binary.
    deps: Dependencies of the binary.
  """
  version_script_flags = [
      "-Wl,--version-script",  # Export only jni functions & classes.
      "$(location {})".format(linkscript),
  ]
  native.cc_binary(
      name = name,
      copts = copts,
      linkopts = linkopts + version_script_flags,
      linkshared = linkshared,
      linkstatic = linkstatic,
      deps = deps + [linkscript],
  )

def tf_to_tflite(name, src, options, out):
  """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.

  Declares a genrule that runs TOCO on `src` and writes `out`.

  Args:
    name: Name of rule.
    src: name of the input graphdef file.
    options: list of extra command line options passed to TOCO.
    out: name of the output flatbuffer file.
  """
  toco = "//tensorflow/contrib/lite/toco:toco"

  # The tool has to be referenced through $(location ...): a bare label is not
  # expanded by the genrule, so the shell would be asked to execute the label
  # text itself instead of the built TOCO binary.
  toco_cmdline = " ".join([
      "$(location %s)" % toco,
      "--input_format=TENSORFLOW_GRAPHDEF",
      "--output_format=TFLITE",
      ("--input_file=$(location %s)" % src),
      ("--output_file=$(location %s)" % out),
  ] + options)
  native.genrule(
      name = name,
      srcs = [src],
      outs = [out],
      cmd = toco_cmdline,
      tools = [toco],
  )

def tflite_to_json(name, src, out):
  """Convert a TF Lite flatbuffer to JSON.

  Declares a genrule that copies `src` to a temporary "<tmp>.bin" file, runs
  flatc on it with the TF Lite schema, and copies the resulting "<tmp>.json"
  to `out`.

  Args:
    name: Name of rule.
    src: name of the input flatbuffer file.
    out: name of the output JSON file.
  """

  flatc = "@flatbuffers//:flatc"
  schema = "//tensorflow/contrib/lite/schema:schema.fbs"

  # flatc writes its output next to nothing in particular: it goes to the
  # directory given with -o. That directory must be the one mktemp used
  # (which follows TMPDIR and is not necessarily /tmp), otherwise the final
  # cp cannot find $${TMP}.json. The trap removes the scratch files whether
  # or not the conversion succeeds; it does not change the exit status.
  cmd = ("TMP=`mktemp`; " +
         "trap 'rm -f $${TMP} $${TMP}.bin $${TMP}.json' EXIT; " +
         "cp $(location %s) $${TMP}.bin && " +
         "$(location %s) --raw-binary --strict-json -t" +
         " -o `dirname $${TMP}` $(location %s) -- $${TMP}.bin && " +
         "cp $${TMP}.json $(location %s)") % (src, flatc, schema, out)
  native.genrule(
      name = name,
      srcs = [schema, src],
      outs = [out],
      cmd = cmd,
      tools = [flatc],
  )

def json_to_tflite(name, src, out):
  """Convert a JSON file to TF Lite's flatbuffer.

  Declares a genrule that copies `src` to a temporary "<tmp>.json" file, runs
  flatc on it with the TF Lite schema, and copies the resulting "<tmp>.bin"
  to `out`.

  Args:
    name: Name of rule.
    src: name of the input JSON file.
    out: name of the output flatbuffer file.
  """

  flatc = "@flatbuffers//:flatc"
  schema = "//tensorflow/contrib/lite/schema:schema_fbs"

  # The -o directory must be the one mktemp used (which follows TMPDIR and is
  # not necessarily /tmp), otherwise the final cp cannot find $${TMP}.bin.
  # The trap removes the scratch files whether or not the conversion
  # succeeds; it does not change the exit status.
  cmd = ("TMP=`mktemp`; " +
         "trap 'rm -f $${TMP} $${TMP}.json $${TMP}.bin' EXIT; " +
         "cp $(location %s) $${TMP}.json && " +
         "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
         " -o `dirname $${TMP}` $(location %s) $${TMP}.json && " +
         "cp $${TMP}.bin $(location %s)") % (src, flatc, schema, out)
  native.genrule(
      name = name,
      srcs = [schema, src],
      outs = [out],
      cmd = cmd,
      tools = [flatc],
  )

# This is the master list of generated examples that will be made into tests. A
# function called make_XXX_tests() must also appear in generate_examples.py.
# Disable a test by commenting it out. If you do, add a link to a bug or issue.
def generated_test_models():
  """Returns the names of the generated example models that become tests.

  Returns:
    a new list of model name strings, in the order they are declared here
  """
  model_names = [
      "add",
      "arg_max",
      "avg_pool",
      "batch_to_space_nd",
      "concat",
      "constant",
      "control_dep",
      "conv",
      "depthwiseconv",
      "div",
      "equal",
      "exp",
      "expand_dims",
      "floor",
      "fully_connected",
      "fused_batch_norm",
      "gather",
      "global_batch_norm",
      "greater",
      "greater_equal",
      "sum",
      "l2norm",
      "l2_pool",
      "less",
      "less_equal",
      "local_response_norm",
      "log_softmax",
      "log",
      "lstm",
      "max_pool",
      "maximum",
      "mean",
      "minimum",
      "mul",
      "neg",
      "not_equal",
      "pad",
      "padv2",
      # "prelu",
      "pow",
      "relu",
      "relu1",
      "relu6",
      "reshape",
      "resize_bilinear",
      "rsqrt",
      "shape",
      "sigmoid",
      "sin",
      "slice",
      "softmax",
      "space_to_batch_nd",
      "space_to_depth",
      "sparse_to_dense",
      "split",
      "sqrt",
      "squeeze",
      "strided_slice",
      "strided_slice_1d_exhaustive",
      "sub",
      "tile",
      "topk",
      "transpose",
      "transpose_conv",
      "where",
  ]
  return model_names

def gen_zip_test(name, test_name, **kwargs):
  """Declares a zipped-example test and the zip file target it depends on.

  A "zip_<test_name>" target producing "<test_name>.zip" is declared through
  gen_zipped_test_file(), followed by the cc_test itself.

  Args:
    name: Resulting cc_test target name
    test_name: Test targets this model. Comes from generated_test_models().
    **kwargs: tf_cc_test kwargs.
  """
  zip_target = "zip_%s" % test_name
  zip_file = "%s.zip" % test_name
  gen_zipped_test_file(name = zip_target, file = zip_file)
  tf_cc_test(name, **kwargs)

def gen_zipped_test_file(name, file):
  """Generate a zip file of tests by using :generate_examples.

  Two targets are declared: a genrule named "`file`.files" whose output is
  `file`, and a filegroup named `name` whose only source is `file`.

  Args:
    name: Name of the filegroup target that exposes the generated zip.
    file: Name of the zip file to produce; gen_zip_test passes
      "<test_name>.zip".
  """
  toco = "//tensorflow/contrib/lite/toco:toco"

  # :generate_examples receives the toco binary, the zip name and the genrule
  # output directory $(@D).
  generate_cmd = (
      "$(locations :generate_examples) --toco $(locations %s) " +
      " --zip_to_output %s $(@D)"
  ) % (toco, file)

  native.genrule(
      name = file + ".files",
      outs = [file],
      cmd = generate_cmd,
      tools = [
          ":generate_examples",
          toco,
      ],
  )

  native.filegroup(
      name = name,
      srcs = [file],
  )

def gen_selected_ops(name, model):
  """Generate the library that includes only used ops.

  Declares a genrule that runs the generate_op_registrations tool on `model`
  and writes "<name>_registration.cc".

  Args:
    name: Name of the generated library.
    model: TFLite model to interpret.
  """
  registration_src = name + "_registration.cc"
  generator = "//tensorflow/contrib/lite/tools:generate_op_registrations"
  tflite_label = "//tensorflow/contrib/lite"

  command_parts = [
      "$(location %s)" % generator,
      "--input_model=$(location %s)" % model,
      "--output_registration=$(location %s)" % registration_src,
      # The tool is given the package path without the leading "//".
      "--tflite_path=%s" % tflite_label[2:],
  ]
  native.genrule(
      name = name,
      srcs = [model],
      outs = [registration_src],
      cmd = " ".join(command_parts),
      tools = [generator],
  )