aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-07-06 14:01:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-07 20:43:57 -0700
commit2d4a76a3df4cb9c4466de3b27adc0c0217d2e59f (patch)
treec5a9524e0367218a1d8b3740016cef153095f563
parent3a8b3f585b0562a9f4913373a12b0e92cddf1589 (diff)
Call QuantizeWeights transformation *after* batchnorms have been folded.
PiperOrigin-RevId: 203521700
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index fc1636831b..3ca36338eb 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -273,13 +273,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new toco::MergeLstmCellInputs);
}
}
- if (toco_flags.quantize_weights()) {
- transformations.Add(new QuantizeWeights);
- }
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+ if (toco_flags.quantize_weights()) {
+ // Run the quantize weights transformation after batchnorms have been
+ // folded into the weights.
+ RunGraphTransformations(model, "quantize weights transformation",
+ {new QuantizeWeights});
+ }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,