aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
blob: cff79776bc787ea8c9b0afa8cce556a175453bd0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace toco {

// Registers every TOCO command-line flag against the corresponding field of
// `*parsed_toco_flags_ptr` and parses `argv` into those fields.
//
// Args:
//   argc, argv: the command line; forwarded to tensorflow::Flags::Parse.
//   msg: receives the usage text (appended) when help is requested.
//   parsed_toco_flags_ptr: destination for the parsed flag values.
//
// Returns false, after appending the usage text to `*msg`, when the command
// line consists of exactly one argument that is `--help` or `-help`.
// Otherwise returns the result of tensorflow::Flags::Parse.
bool ParseTocoFlagsFromCommandLineFlags(
    int* argc, char* argv[], string* msg,
    ParsedTocoFlags* parsed_toco_flags_ptr) {
  using tensorflow::Flag;
  ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
  std::vector<tensorflow::Flag> flags = {
      Flag("input_file", parsed_flags.input_file.bind(),
           parsed_flags.input_file.default_value(),
           "Input file (model of any supported format). For Protobuf "
           "formats, both text and binary are supported regardless of file "
           "extension."),
      Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
           parsed_flags.savedmodel_directory.default_value(),
           "Deprecated. Full path to the directory containing the SavedModel."),
      Flag("output_file", parsed_flags.output_file.bind(),
           parsed_flags.output_file.default_value(),
           "Output file. "
           "For Protobuf formats, the binary format will be used."),
      Flag("input_format", parsed_flags.input_format.bind(),
           parsed_flags.input_format.default_value(),
           "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
      Flag("output_format", parsed_flags.output_format.bind(),
           parsed_flags.output_format.default_value(),
           "Output file format. "
           "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
      Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
           parsed_flags.savedmodel_tagset.default_value(),
           "Deprecated. Comma-separated set of tags identifying the "
           "MetaGraphDef within the SavedModel to analyze. All tags in the tag "
           "set must be specified."),
      Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
           parsed_flags.default_ranges_min.default_value(),
           "If defined, will be used as the default value for the min bound "
           "of min/max ranges used for quantization of uint8 arrays."),
      Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
           parsed_flags.default_ranges_max.default_value(),
           "If defined, will be used as the default value for the max bound "
           "of min/max ranges used for quantization of uint8 arrays."),
      Flag("default_int16_ranges_min",
           parsed_flags.default_int16_ranges_min.bind(),
           parsed_flags.default_int16_ranges_min.default_value(),
           "If defined, will be used as the default value for the min bound "
           "of min/max ranges used for quantization of int16 arrays."),
      Flag("default_int16_ranges_max",
           parsed_flags.default_int16_ranges_max.bind(),
           parsed_flags.default_int16_ranges_max.default_value(),
           "If defined, will be used as the default value for the max bound "
           "of min/max ranges used for quantization of int16 arrays."),
      Flag("inference_type", parsed_flags.inference_type.bind(),
           parsed_flags.inference_type.default_value(),
           "Target data type of arrays in the output file (for input_arrays, "
           "this may be overridden by inference_input_type). "
           "One of FLOAT, QUANTIZED_UINT8."),
      Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
           parsed_flags.inference_input_type.default_value(),
           "Target data type of input arrays. "
           "If not specified, inference_type is used. "
           "One of FLOAT, QUANTIZED_UINT8."),
      Flag("input_type", parsed_flags.input_type.bind(),
           parsed_flags.input_type.default_value(),
           "Deprecated ambiguous flag that set both --input_data_types and "
           "--inference_input_type."),
      Flag("input_types", parsed_flags.input_types.bind(),
           parsed_flags.input_types.default_value(),
           "Deprecated ambiguous flag that set both --input_data_types and "
           "--inference_input_type. Was meant to be a "
           "comma-separated list, but this was deprecated before "
           "multiple-input-types was ever properly supported."),

      Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
           parsed_flags.drop_fake_quant.default_value(),
           "Ignore and discard FakeQuant nodes. For instance, to "
           "generate plain float code without fake-quantization from a "
           "quantized graph."),
      Flag(
          "reorder_across_fake_quant",
          parsed_flags.reorder_across_fake_quant.bind(),
          parsed_flags.reorder_across_fake_quant.default_value(),
          "Normally, FakeQuant nodes must be strict boundaries for graph "
          "transformations, in order to ensure that quantized inference has "
          "the exact same arithmetic behavior as quantized training --- which "
          "is the whole point of quantized training and of FakeQuant nodes in "
          "the first place. "
          "However, that entails subtle requirements on where exactly "
          "FakeQuant nodes must be placed in the graph. Some quantized graphs "
          "have FakeQuant nodes at unexpected locations, that prevent graph "
          "transformations that are necessary in order to generate inference "
          "code for these graphs. Such graphs should be fixed, but as a "
          "temporary work-around, setting this reorder_across_fake_quant flag "
          "allows TOCO to perform necessary graph transformations on them, "
          "at the cost of no longer faithfully matching inference and training "
          "arithmetic."),
      Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
           parsed_flags.allow_custom_ops.default_value(),
           "If true, allow TOCO to create TF Lite Custom operators for all the "
           "unsupported TensorFlow ops."),
      Flag(
          "drop_control_dependency",
          parsed_flags.drop_control_dependency.bind(),
          parsed_flags.drop_control_dependency.default_value(),
          "If true, ignore control dependency requirements in input TensorFlow "
          "GraphDef. Otherwise an error will be raised upon control dependency "
          "inputs."),
      Flag("debug_disable_recurrent_cell_fusion",
           parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
           parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
           "If true, disable fusion of known identifiable cell subgraphs into "
           "cells. This includes, for example, specific forms of LSTM cell."),
      Flag("propagate_fake_quant_num_bits",
           parsed_flags.propagate_fake_quant_num_bits.bind(),
           parsed_flags.propagate_fake_quant_num_bits.default_value(),
           "If true, use FakeQuant* operator num_bits attributes to adjust "
           "array data_types."),
      Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
           parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
           parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
               .default_value(),
           "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
           "value 0. This flag allows nudging them to 1 to allow proceeding, "
           "with moderate inaccuracy."),
      Flag("dedupe_array_min_size_bytes",
           parsed_flags.dedupe_array_min_size_bytes.bind(),
           parsed_flags.dedupe_array_min_size_bytes.default_value(),
           "Minimum size of constant arrays to deduplicate; arrays smaller "
           "will not be deduplicated."),
      Flag("split_tflite_lstm_inputs",
           parsed_flags.split_tflite_lstm_inputs.bind(),
           parsed_flags.split_tflite_lstm_inputs.default_value(),
           "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
           "Ignored if the output format is not TFLite."),
      Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
           parsed_flags.quantize_weights.default_value(),
           "Deprecated. Please use --post_training_quantize instead."),
      Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
           parsed_flags.post_training_quantize.default_value(),
           "Boolean indicating whether to quantize the weights of the "
           "converted float model. Model size will be reduced and there will "
           "be latency improvements (at the cost of accuracy)."),
      // WARNING: Experimental interface, subject to change
      Flag("allow_flex_ops", parsed_flags.allow_flex_ops.bind(),
           parsed_flags.allow_flex_ops.default_value(), ""),
      // WARNING: Experimental interface, subject to change
      Flag("force_flex_ops", parsed_flags.force_flex_ops.bind(),
           parsed_flags.force_flex_ops.default_value(), "")};

  // Help is only recognized when it is the sole argument after the program
  // name; any other command line goes through the regular flag parser.
  bool asked_for_help = false;
  if (*argc == 2) {
    const std::string only_arg(argv[1]);
    asked_for_help = (only_arg == "--help" || only_arg == "-help");
  }
  if (asked_for_help) {
    *msg += tensorflow::Flags::Usage(argv[0], flags);
    return false;
  }
  return tensorflow::Flags::Parse(argc, argv, flags);
}

namespace {

// Defines the requirements for a given flag. kUseDefault means the default
// should be used in cases where the value isn't specified by the user.
//
// Consumed by EnforceFlagRequirement (which checks the "must" variants) and
// GetFlagValue (which honors kUseDefault).
enum class FlagRequirement {
  // No check is performed; an unspecified flag yields no value.
  kNone,
  // The flag must have been specified; otherwise a QCHECK fails.
  kMustBeSpecified,
  // The flag must not have been specified; otherwise a QCHECK fails.
  kMustNotBeSpecified,
  // No check is performed; an unspecified flag yields its default value.
  kUseDefault,
};

// Verifies that `flag` satisfies `requirement`, failing a QCHECK that names
// `flag_name` when it does not. Only kMustBeSpecified and kMustNotBeSpecified
// impose a check; every other requirement is accepted unconditionally.
template <typename T>
void EnforceFlagRequirement(const T& flag, const string& flag_name,
                            FlagRequirement requirement) {
  switch (requirement) {
    case FlagRequirement::kMustBeSpecified:
      QCHECK(flag.specified()) << "Missing required flag " << flag_name;
      break;
    case FlagRequirement::kMustNotBeSpecified:
      QCHECK(!flag.specified())
          << "Given other flags, this flag should not have been specified: "
          << flag_name;
      break;
    case FlagRequirement::kNone:
    case FlagRequirement::kUseDefault:
      // Nothing to enforce for these requirements.
      break;
  }
}

// Resolves the effective value of `flag`:
//   - the user-supplied value when the flag was specified;
//   - the flag's default when it was not specified and `requirement` is
//     kUseDefault;
//   - an empty optional in every other case.
template <typename T>
absl::optional<T> GetFlagValue(const Arg<T>& flag,
                               FlagRequirement requirement) {
  absl::optional<T> resolved;
  if (flag.specified()) {
    resolved = flag.value();
  } else if (requirement == FlagRequirement::kUseDefault) {
    resolved = flag.default_value();
  }
  return resolved;
}

}  // namespace

// Transfers the command-line values held in `parsed_toco_flags` into the
// `toco_flags` proto.
//
// Each flag is checked against its FlagRequirement and, when it has an
// effective value (specified, or defaulted under kUseDefault), the matching
// `set_<name>` setter is called. Enum-typed flags are converted from their
// string spelling with `<Type>_Parse`; an unrecognized spelling fails a QCHECK.
//
// After the generic pass, the deprecated flags are handled:
//   - --input_type / --input_types set inference_input_type (with a warning).
//   - --quantize_weights is forwarded to post_training_quantize (with a
//     warning).
// Finally, if post-training quantization is enabled while inference_type is
// QUANTIZED_UINT8, inference_type is overridden to FLOAT (with a warning).
//
// NOTE(review): CheckInitGoogleIsDone presumably asserts that process
// initialization has already run -- confirm against toco_port.
void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
                                       TocoFlags* toco_flags) {
  namespace port = toco::port;
  port::CheckInitGoogleIsDone("InitGoogle is not done yet");

// Copies a flag whose value type matches the proto field type directly.
#define READ_TOCO_FLAG(name, requirement)                                \
  do {                                                                   \
    EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
    auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
    if (flag_value.has_value()) {                                        \
      toco_flags->set_##name(flag_value.value());                        \
    }                                                                    \
  } while (false)

// Copies a string flag into an enum-typed proto field, converting it with
// Type##_Parse. The failure message reports the string that was parsed.
#define PARSE_TOCO_FLAG(Type, name, requirement)                         \
  do {                                                                   \
    EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
    auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
    if (flag_value.has_value()) {                                        \
      Type x;                                                            \
      QCHECK(Type##_Parse(flag_value.value(), &x))                       \
          << "Unrecognized " << #Type << " value "                       \
          << flag_value.value();                                         \
      toco_flags->set_##name(x);                                         \
    }                                                                    \
  } while (false)

  PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
  PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
  PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
  PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
  READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
  READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
  READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
  READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
  READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
  READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
  READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
  READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
  READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
  READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
  READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
                 FlagRequirement::kNone);
  READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
  READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
  READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
  READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
  READ_TOCO_FLAG(allow_flex_ops, FlagRequirement::kNone);
  READ_TOCO_FLAG(force_flex_ops, FlagRequirement::kNone);

  if (parsed_toco_flags.force_flex_ops.value() &&
      !parsed_toco_flags.allow_flex_ops.value()) {
    // TODO(ycling): Consider to enforce `allow_flex_ops` when
    // `force_flex_ops` is true.
    LOG(WARNING) << "--force_flex_ops should always be used with "
                    "--allow_flex_ops.";
  }

  // Deprecated flag handling.
  if (parsed_toco_flags.input_type.specified()) {
    LOG(WARNING)
        << "--input_type is deprecated. It was an ambiguous flag that set both "
           "--input_data_types and --inference_input_type. If you are trying "
           "to complement the input file with information about the type of "
           "input arrays, use --input_data_type. If you are trying to control "
           "the quantization/dequantization of real-numbers input arrays in "
           "the output file, use --inference_input_type.";
    toco::IODataType input_type;
    QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
                                  &input_type))
        << "Unrecognized IODataType value "
        << parsed_toco_flags.input_type.value();
    toco_flags->set_inference_input_type(input_type);
  }
  if (parsed_toco_flags.input_types.specified()) {
    LOG(WARNING)
        << "--input_types is deprecated. It was an ambiguous flag that set "
           "both --input_data_types and --inference_input_type. If you are "
           "trying to complement the input file with information about the "
           "type of input arrays, use --input_data_type. If you are trying to "
           "control the quantization/dequantization of real-numbers input "
           "arrays in the output file, use --inference_input_type.";
    std::vector<string> input_types =
        absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
    QCHECK(!input_types.empty());
    // Multiple distinct input types were never supported: every entry in the
    // list has to agree with the first one.
    for (size_t i = 1; i < input_types.size(); i++) {
      QCHECK_EQ(input_types[i], input_types[0]);
    }
    toco::IODataType input_type;
    QCHECK(toco::IODataType_Parse(input_types[0], &input_type))
        << "Unrecognized IODataType value " << input_types[0];
    toco_flags->set_inference_input_type(input_type);
  }
  if (parsed_toco_flags.quantize_weights.value()) {
    LOG(WARNING)
        << "--quantize_weights is deprecated. Falling back to "
           "--post_training_quantize. Please switch to "
           "--post_training_quantize.";
    toco_flags->set_post_training_quantize(
        parsed_toco_flags.quantize_weights.value());
  }
  // Checked on the proto rather than on the deprecated flag so that the
  // override applies whether post-training quantization was requested through
  // --post_training_quantize or through the deprecated --quantize_weights.
  if (toco_flags->post_training_quantize() &&
      toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
    LOG(WARNING)
        << "--post_training_quantize quantizes a graph of inference_type "
           "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
    toco_flags->set_inference_type(IODataType::FLOAT);
  }

#undef READ_TOCO_FLAG
#undef PARSE_TOCO_FLAG
}
}  // namespace toco