diff options
| author | Suharsh Sivakumar <suharshs@google.com> | 2018-08-31 12:16:27 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 12:20:45 -0700 |
| commit | 2b6e2f166e0e25984c32d3df48ba848c7f08b96b (patch) | |
| tree | 4be2c6d942191018f3708e95457761c5127d903f | |
| parent | cda5ea80b86909fd20ff8a0f5ba914c5c03b876f (diff) | |
Introduce post_training_quantize flag and deprecate quantize_weights flag.
PiperOrigin-RevId: 211124183
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/args.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 27 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_flags.proto | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 3 |
9 files changed, 69 insertions, 37 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 69a3d562b3..1c5516ae7c 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -126,7 +126,7 @@ def build_toco_convert_protos(input_tensors, reorder_across_fake_quant=False, allow_custom_ops=False, change_concat_input_ranges=False, - quantize_weights=False, + post_training_quantize=False, dump_graphviz_dir=None, dump_graphviz_video=False): """Builds protocol buffers describing a conversion of a model using TOCO. @@ -173,9 +173,9 @@ def build_toco_convert_protos(input_tensors, change_concat_input_ranges: Boolean to change behavior of min/max ranges for inputs and outputs of the concat operator for quantized models. Changes the ranges of concat operator overlap when true. (default False) - quantize_weights: Boolean indicating whether to store weights as quantized - weights followed by dequantize operations. Computation is still done in - float, but reduces model size (at the cost of accuracy and latency). + post_training_quantize: Boolean indicating whether to quantize the weights + of the converted float model. Model size will be reduced and there will be + latency improvements (at the cost of accuracy). (default False) dump_graphviz_dir: Full filepath of folder to dump the graphs at various stages of processing GraphViz .dot files. 
Preferred over @@ -204,7 +204,7 @@ def build_toco_convert_protos(input_tensors, toco.drop_control_dependency = drop_control_dependency toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops - toco.quantize_weights = quantize_weights + toco.post_training_quantize = post_training_quantize if default_ranges_stats: toco.default_ranges_min = default_ranges_stats[0] toco.default_ranges_max = default_ranges_stats[1] diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 80cbb12825..2de97fec86 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -102,9 +102,9 @@ class TocoConverter(object): created for any op that is unknown. The developer will need to provide these to the TensorFlow Lite runtime with a custom resolver. (default False) - quantize_weights: Boolean indicating whether to store weights as quantized - weights followed by dequantize operations. Computation is still done in - float, but reduces model size (at the cost of accuracy and latency). + post_training_quantize: Boolean indicating whether to quantize the weights + of the converted float model. Model size will be reduced and there will be + latency improvements (at the cost of accuracy). (default False) dump_graphviz_dir: Full filepath of folder to dump the graphs at various stages of processing GraphViz .dot files. 
Preferred over @@ -175,7 +175,7 @@ class TocoConverter(object): self.reorder_across_fake_quant = False self.change_concat_input_ranges = False self.allow_custom_ops = False - self.quantize_weights = False + self.post_training_quantize = False self.dump_graphviz_dir = None self.dump_graphviz_video = False @@ -425,7 +425,7 @@ class TocoConverter(object): "reorder_across_fake_quant": self.reorder_across_fake_quant, "change_concat_input_ranges": self.change_concat_input_ranges, "allow_custom_ops": self.allow_custom_ops, - "quantize_weights": self.quantize_weights, + "post_training_quantize": self.post_training_quantize, "dump_graphviz_dir": self.dump_graphviz_dir, "dump_graphviz_video": self.dump_graphviz_video } diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index d004c3ecca..1c94ba605a 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -372,7 +372,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale - def testQuantizeWeights(self): + def testPostTrainingQuantize(self): np.random.seed(0) # We need the tensor to have more than 1024 elements for quantize_weights # to kick in. Thus, the [33, 33] shape. @@ -393,14 +393,14 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(float_tflite) # Convert quantized weights model. 
- quantized_weights_converter = lite.TocoConverter.from_session( + quantized_converter = lite.TocoConverter.from_session( sess, [in_tensor_1], [out_tensor]) - quantized_weights_converter.quantize_weights = True - quantized_weights_tflite = quantized_weights_converter.convert() - self.assertTrue(quantized_weights_tflite) + quantized_converter.post_training_quantize = True + quantized_tflite = quantized_converter.convert() + self.assertTrue(quantized_tflite) # Ensure that the quantized weights tflite model is smaller. - self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + self.assertTrue(len(quantized_tflite) < len(float_tflite)) class FromFrozenGraphFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index dc078ffd21..cc08ed3fe9 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -142,11 +142,14 @@ def _convert_model(flags): flags.change_concat_input_ranges == "TRUE") if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops - if flags.quantize_weights: + + if flags.post_training_quantize: + converter.post_training_quantize = flags.post_training_quantize if flags.inference_type == lite_constants.QUANTIZED_UINT8: - raise ValueError("--quantized_weights is not supported with " - "--inference_type=QUANTIZED_UINT8") - converter.quantize_weights = flags.quantize_weights + print("--post_training_quantize quantizes a graph of inference_type " + "FLOAT. 
Overriding inference type QUANTIZED_UINT8 to FLOAT.") + converter.inference_type = lite_constants.FLOAT + if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir if flags.dump_graphviz_video: @@ -318,12 +321,20 @@ def run_main(_): help=("Default value for max bound of min/max range values used for all " "arrays without a specified range, Intended for experimenting with " "quantization via \"dummy quantization\". (default None)")) + # quantize_weights is DEPRECATED. parser.add_argument( "--quantize_weights", + dest="post_training_quantize", + action="store_true", + help=argparse.SUPPRESS) + parser.add_argument( + "--post_training_quantize", + dest="post_training_quantize", action="store_true", - help=("Store float weights as quantized weights followed by dequantize " - "operations. Inference is still done in FLOAT, but reduces model " - "size (at the cost of accuracy and latency).")) + help=( + "Boolean indicating whether to quantize the weights of the " + "converted float model. Model size will be reduced and there will " + "be latency improvements (at the cost of accuracy). (default False)")) # Graph manipulation flags. 
parser.add_argument( diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index aef35ad490..84f71dc7a7 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -236,8 +236,9 @@ struct ParsedTocoFlags { Arg<bool> drop_fake_quant = Arg<bool>(false); Arg<bool> reorder_across_fake_quant = Arg<bool>(false); Arg<bool> allow_custom_ops = Arg<bool>(false); - Arg<bool> quantize_weights = Arg<bool>(false); + Arg<bool> post_training_quantize = Arg<bool>(false); // Deprecated flags + Arg<bool> quantize_weights = Arg<bool>(false); Arg<string> input_type; Arg<string> input_types; Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false); diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 1de32f9977..00bc8d4ccb 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -149,10 +149,10 @@ have. true, custom ops are created for any op that is unknown. The developer will need to provide these to the TensorFlow Lite runtime with a custom resolver. -* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to - store weights as quantized weights followed by dequantize operations. - Computation is still done in float, but reduces model size (at the cost of - accuracy and latency). +* `--post_training_quantize`. Type: boolean. Default: False. Boolean + indicating whether to quantize the weights of the converted float model. + Model size will be reduced and there will be latency improvements (at the + cost of accuracy). 
## Logging flags diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index c6d0a03452..f83a290195 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -160,10 +160,12 @@ bool ParseTocoFlagsFromCommandLineFlags( "Ignored if the output format is not TFLite."), Flag("quantize_weights", parsed_flags.quantize_weights.bind(), parsed_flags.quantize_weights.default_value(), - "Store weights as quantized weights followed by dequantize " - "operations. Computation is still done in float, but reduces model " - "size (at the cost of accuracy and latency)."), - }; + "Deprecated. Please use --post_training_quantize instead."), + Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(), + parsed_flags.post_training_quantize.default_value(), + "Boolean indicating whether to quantize the weights of the " + "converted float model. Model size will be reduced and there will " + "be latency improvements (at the cost of accuracy).")}; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); if (asked_for_help) { @@ -257,6 +259,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); + READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); // Deprecated flag handling. 
if (parsed_toco_flags.input_type.specified()) { @@ -291,9 +294,19 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, toco_flags->set_inference_input_type(input_type); } if (parsed_toco_flags.quantize_weights.value()) { - QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8) - << "quantize_weights is not supported with inference_type " - "QUANTIZED_UINT8."; + LOG(WARNING) + << "--quantize_weights is deprecated. Falling back to " + "--post_training_quantize. Please switch --post_training_quantize."; + toco_flags->set_post_training_quantize( + parsed_toco_flags.quantize_weights.value()); + } + if (parsed_toco_flags.quantize_weights.value()) { + if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) { + LOG(WARNING) + << "--post_training_quantize quantizes a graph of inference_type " + "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT."; + toco_flags->set_inference_type(IODataType::FLOAT); + } } #undef READ_TOCO_FLAG diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index b4a9870d58..c1dd621429 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 26. +// Next ID to use: 27. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -173,6 +173,7 @@ message TocoFlags { // Store weights as quantized weights followed by dequantize operations. // Computation is still done in float, but reduces model size (at the cost of // accuracy and latency). + // DEPRECATED: Please use post_training_quantize instead. 
optional bool quantize_weights = 20 [default = false]; // Full filepath of folder to dump the graphs at various stages of processing @@ -183,4 +184,9 @@ message TocoFlags { // Boolean indicating whether to dump the graph after every graph // transformation. optional bool dump_graphviz_include_video = 25; + + // Boolean indicating whether to quantize the weights of the converted float + // model. Model size will be reduced and there will be latency improvements + // (at the cost of accuracy). + optional bool post_training_quantize = 26 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 243d0dabdb..7db7acb44d 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -399,7 +399,8 @@ void Export(const TocoFlags& toco_flags, const Model& model, break; case TFLITE: toco::tflite::Export(model, allow_custom_ops, - toco_flags.quantize_weights(), output_file_contents); + toco_flags.post_training_quantize(), + output_file_contents); break; case GRAPHVIZ_DOT: DumpGraphviz(model, output_file_contents); |