aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/toco_tooling.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/toco_tooling.cc')
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc20
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a057dcef12..aa7f6996eb 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -55,7 +55,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ConvertExpandDimsToReshape);
transformations->Add(new ConvertSqueezeToReshape);
transformations->Add(new ConvertTrivialAddNToAdd);
- transformations->Add(new ConvertTrivialStackToReshape);
+ transformations->Add(new ConvertTrivialPackToReshape);
transformations->Add(new ConvertTrivialTileToConcat);
transformations->Add(new ConvertTrivialTransposeToReshape);
transformations->Add(new ConvertReorderAxes);
@@ -79,17 +79,18 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new FuseBinaryIntoFollowingAffine);
transformations->Add(new FuseBroadcastIntoFollowingBinary);
transformations->Add(new MergeReshapeIntoPrecedingTranspose);
+ transformations->Add(new MoveBinaryOperatorBeforeReshape);
transformations->Add(new ReorderElementwiseUnary);
transformations->Add(new ReorderReshapeTranspose);
transformations->Add(new ResolveBatchNormalization);
transformations->Add(new ResolveConstantBinaryOperator);
transformations->Add(new ResolveConstantFill);
transformations->Add(new ResolveConstantGather);
+ transformations->Add(new ResolveConstantPack);
transformations->Add(new ResolveConstantRandomUniform);
transformations->Add(new ResolveConstantRange);
transformations->Add(new ResolveConstantReshape);
transformations->Add(new ResolveConstantSlice);
- transformations->Add(new ResolveConstantStack);
transformations->Add(new ResolveConstantStridedSlice);
transformations->Add(new ResolveConstantTranspose);
transformations->Add(new ResolveConstantUnaryOperator);
@@ -104,17 +105,19 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new IdentifyRelu1);
transformations->Add(new IdentifyPRelu);
transformations->Add(new RemoveTrivialBinaryOperator);
- transformations->Add(new ReadFakeQuantMinMax);
+ transformations->Add(new ResolveFakeQuantArgsFromVars);
+ transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
transformations->Add(new ResolveSpaceToBatchNDAttributes);
transformations->Add(new ResolveBatchToSpaceNDAttributes);
transformations->Add(new ResolvePadAttributes);
transformations->Add(new ResolvePadV2Attributes);
transformations->Add(new ResolveStridedSliceAttributes);
transformations->Add(new ResolveSliceAttributes);
- transformations->Add(new ResolveMeanAttributes);
+ transformations->Add(new ResolveReduceAttributes);
transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
transformations->Add(new UnpartitionEmbeddingLookup);
+ transformations->Add(new ResolveGatherAttributes);
}
bool SupportsQuantization(FileFormat format) {
@@ -272,13 +275,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new toco::MergeLstmCellInputs);
}
}
- if (toco_flags.quantize_weights()) {
- transformations.Add(new QuantizeWeights);
- }
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+ if (toco_flags.quantize_weights()) {
+ // Run the quantize weights transformation after batchnorms have been
+ // folded into the weights.
+ RunGraphTransformations(model, "quantize weights transformation",
+ {new QuantizeWeights});
+ }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,