diff options
author | 2018-07-06 14:01:23 -0700 | |
---|---|---|
committer | 2018-07-07 20:43:57 -0700 | |
commit | 2d4a76a3df4cb9c4466de3b27adc0c0217d2e59f (patch) | |
tree | c5a9524e0367218a1d8b3740016cef153095f563 | |
parent | 3a8b3f585b0562a9f4913373a12b0e92cddf1589 (diff) |
Call QuantizeWeights transformation *after* batchnorms have been folded.
PiperOrigin-RevId: 203521700
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index fc1636831b..3ca36338eb 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -273,13 +273,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) { transformations.Add(new toco::MergeLstmCellInputs); } } - if (toco_flags.quantize_weights()) { - transformations.Add(new QuantizeWeights); - } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); + if (toco_flags.quantize_weights()) { + // Run the quantize weights transformation after batchnorms have been + // folded into the weights. + RunGraphTransformations(model, "quantize weights transformation", + {new QuantizeWeights}); + } if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { RunGraphTransformations(model, |