author Suharsh Sivakumar <suharshs@google.com> 2018-08-31 12:16:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-31 12:20:45 -0700
commit 2b6e2f166e0e25984c32d3df48ba848c7f08b96b (patch)
tree 4be2c6d942191018f3708e95457761c5127d903f
parent cda5ea80b86909fd20ff8a0f5ba914c5c03b876f (diff)
Introduce post_training_quantize flag and deprecate quantize_weights flag.
PiperOrigin-RevId: 211124183
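For context, a minimal sketch of the renamed Python API, mirroring the usage exercised in lite_test.py below; the toy graph here is illustrative, not part of this change:

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.lite.python import lite

    with tf.Session() as sess:
        # A float graph whose weight tensor has more than 1024 elements,
        # so the weight-quantization pass kicks in (see lite_test.py below).
        inp = tf.placeholder(tf.float32, shape=[33, 33], name='input')
        weights = tf.constant(
            np.random.uniform(low=-10.0, high=10.0, size=(33, 33)),
            dtype=tf.float32)
        out = tf.matmul(inp, weights, name='output')

        converter = lite.TocoConverter.from_session(sess, [inp], [out])
        converter.post_training_quantize = True  # formerly: quantize_weights
        tflite_model = converter.convert()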
-rw-r--r--  tensorflow/contrib/lite/python/convert.py               | 10
-rw-r--r--  tensorflow/contrib/lite/python/lite.py                  | 10
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py             | 12
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py        | 25
-rw-r--r--  tensorflow/contrib/lite/toco/args.h                     |  3
-rw-r--r--  tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md |  8
-rw-r--r--  tensorflow/contrib/lite/toco/toco_cmdline_flags.cc      | 27
-rw-r--r--  tensorflow/contrib/lite/toco/toco_flags.proto           |  8
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc            |  3
9 files changed, 69 insertions, 37 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 69a3d562b3..1c5516ae7c 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -126,7 +126,7 @@ def build_toco_convert_protos(input_tensors,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
- quantize_weights=False,
+ post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
@@ -173,9 +173,9 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -204,7 +204,7 @@ def build_toco_convert_protos(input_tensors,
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
- toco.quantize_weights = quantize_weights
+ toco.post_training_quantize = post_training_quantize
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
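The new keyword flows straight into the TocoFlags proto (the field is added in toco_flags.proto below). A hedged sketch, assuming build_toco_convert_protos returns a (model_flags, toco_flags) pair, and reusing inp/out from the sketch above:

    from tensorflow.contrib.lite.python import convert

    # inp/out: the input and output tensors from the earlier sketch.
    model_flags, toco_flags = convert.build_toco_convert_protos(
        input_tensors=[inp], output_tensors=[out],
        post_training_quantize=True)
    assert toco_flags.post_training_quantize  # proto field 26, default False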
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 80cbb12825..2de97fec86 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -102,9 +102,9 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -175,7 +175,7 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
- self.quantize_weights = False
+ self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
@@ -425,7 +425,7 @@ class TocoConverter(object):
"reorder_across_fake_quant": self.reorder_across_fake_quant,
"change_concat_input_ranges": self.change_concat_input_ranges,
"allow_custom_ops": self.allow_custom_ops,
- "quantize_weights": self.quantize_weights,
+ "post_training_quantize": self.post_training_quantize,
"dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video
}
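Note that the Python attribute is renamed with no alias: assigning to converter.quantize_weights on a new TocoConverter silently sets an unused attribute rather than failing. A hypothetical guard for scripts that must run on both sides of the rename:

    # 'converter' is a TocoConverter; the hasattr check is the version probe.
    if hasattr(converter, 'post_training_quantize'):
        converter.post_training_quantize = True
    else:
        converter.quantize_weights = True  # pre-rename TensorFlow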
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index d004c3ecca..1c94ba605a 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -372,7 +372,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
- def testQuantizeWeights(self):
+ def testPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
@@ -393,14 +393,14 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_weights_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TocoConverter.from_session(
sess, [in_tensor_1], [out_tensor])
- quantized_weights_converter.quantize_weights = True
- quantized_weights_tflite = quantized_weights_converter.convert()
- self.assertTrue(quantized_weights_tflite)
+ quantized_converter.post_training_quantize = True
+ quantized_tflite = quantized_converter.convert()
+ self.assertTrue(quantized_tflite)
# Ensure that the quantized weights tflite model is smaller.
- self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+ self.assertTrue(len(quantized_tflite) < len(float_tflite))
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index dc078ffd21..cc08ed3fe9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -142,11 +142,14 @@ def _convert_model(flags):
flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
- if flags.quantize_weights:
+
+ if flags.post_training_quantize:
+ converter.post_training_quantize = flags.post_training_quantize
if flags.inference_type == lite_constants.QUANTIZED_UINT8:
- raise ValueError("--quantized_weights is not supported with "
- "--inference_type=QUANTIZED_UINT8")
- converter.quantize_weights = flags.quantize_weights
+ print("--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
+ converter.inference_type = lite_constants.FLOAT
+
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
if flags.dump_graphviz_video:
@@ -318,12 +321,20 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ # quantize_weights is DEPRECATED.
parser.add_argument(
"--quantize_weights",
+ dest="post_training_quantize",
+ action="store_true",
+ help=argparse.SUPPRESS)
+ parser.add_argument(
+ "--post_training_quantize",
+ dest="post_training_quantize",
action="store_true",
- help=("Store float weights as quantized weights followed by dequantize "
- "operations. Inference is still done in FLOAT, but reduces model "
- "size (at the cost of accuracy and latency)."))
+ help=(
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy). (default False)"))
# Graph manipulation flags.
parser.add_argument(
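The CLI keeps the old spelling as a hidden alias by pointing both flags at the same dest and hiding the deprecated one from --help. A standalone sketch of that argparse pattern:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--post_training_quantize',
                        dest='post_training_quantize', action='store_true',
                        help='Quantize the weights of the converted float model.')
    # Deprecated spelling: same dest, hidden from --help via argparse.SUPPRESS.
    parser.add_argument('--quantize_weights', dest='post_training_quantize',
                        action='store_true', help=argparse.SUPPRESS)

    # The old flag still works and sets the new attribute.
    assert parser.parse_args(['--quantize_weights']).post_training_quantize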
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index aef35ad490..84f71dc7a7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -236,8 +236,9 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
- Arg<bool> quantize_weights = Arg<bool>(false);
+ Arg<bool> post_training_quantize = Arg<bool>(false);
// Deprecated flags
+ Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index 1de32f9977..00bc8d4ccb 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -149,10 +149,10 @@ have.
true, custom ops are created for any op that is unknown. The developer will
need to provide these to the TensorFlow Lite runtime with a custom resolver.
-* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to
- store weights as quantized weights followed by dequantize operations.
- Computation is still done in float, but reduces model size (at the cost of
- accuracy and latency).
+* `--post_training_quantize`. Type: boolean. Default: False. Boolean
+ indicating whether to quantize the weights of the converted float model.
+ Model size will be reduced and there will be latency improvements (at the
+ cost of accuracy).
## Logging flags
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index c6d0a03452..f83a290195 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -160,10 +160,12 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Ignored if the output format is not TFLite."),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
- "Store weights as quantized weights followed by dequantize "
- "operations. Computation is still done in float, but reduces model "
- "size (at the cost of accuracy and latency)."),
- };
+ "Deprecated. Please use --post_training_quantize instead."),
+ Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
+ parsed_flags.post_training_quantize.default_value(),
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy).")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -257,6 +259,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
+ READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -291,9 +294,19 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
toco_flags->set_inference_input_type(input_type);
}
if (parsed_toco_flags.quantize_weights.value()) {
- QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
- << "quantize_weights is not supported with inference_type "
- "QUANTIZED_UINT8.";
+ LOG(WARNING)
+ << "--quantize_weights is deprecated. Falling back to "
+ "--post_training_quantize. Please switch --post_training_quantize.";
+ toco_flags->set_post_training_quantize(
+ parsed_toco_flags.quantize_weights.value());
+ }
+ if (toco_flags->post_training_quantize()) {
+ if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
+ LOG(WARNING)
+ << "--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
}
#undef READ_TOCO_FLAG
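The handling above first folds the deprecated flag into the new one, then normalizes the inference type, since post-training quantization operates on a float graph. A plain-Python restatement of that precedence (illustrative only):

    def resolve(quantize_weights, post_training_quantize, inference_type):
        if quantize_weights:  # deprecated flag implies the new one
            post_training_quantize = True
        if post_training_quantize and inference_type == 'QUANTIZED_UINT8':
            inference_type = 'FLOAT'  # the pass quantizes a float graph
        return post_training_quantize, inference_type

    assert resolve(True, False, 'QUANTIZED_UINT8') == (True, 'FLOAT')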
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index b4a9870d58..c1dd621429 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 26.
+// Next ID to use: 27.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -173,6 +173,7 @@ message TocoFlags {
// Store weights as quantized weights followed by dequantize operations.
// Computation is still done in float, but reduces model size (at the cost of
// accuracy and latency).
+ // DEPRECATED: Please use post_training_quantize instead.
optional bool quantize_weights = 20 [default = false];
// Full filepath of folder to dump the graphs at various stages of processing
@@ -183,4 +184,9 @@ message TocoFlags {
// Boolean indicating whether to dump the graph after every graph
// transformation.
optional bool dump_graphviz_include_video = 25;
+
+ // Boolean indicating whether to quantize the weights of the converted float
+ // model. Model size will be reduced and there will be latency improvements
+ // (at the cost of accuracy).
+ optional bool post_training_quantize = 26 [default = false];
}
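For completeness, the new field as seen from Python; the toco_flags_pb2 module name is an assumption based on the usual protobuf codegen for toco_flags.proto:

    from tensorflow.contrib.lite.toco import toco_flags_pb2

    flags = toco_flags_pb2.TocoFlags()
    flags.post_training_quantize = True  # new field 26
    assert not flags.quantize_weights    # deprecated field 20 keeps its default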
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 243d0dabdb..7db7acb44d 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -399,7 +399,8 @@ void Export(const TocoFlags& toco_flags, const Model& model,
break;
case TFLITE:
toco::tflite::Export(model, allow_custom_ops,
- toco_flags.quantize_weights(), output_file_contents);
+ toco_flags.post_training_quantize(),
+ output_file_contents);
break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);