diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/toco_tooling.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index a057dcef12..aa7f6996eb 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -55,7 +55,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); - transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialPackToReshape); transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); @@ -79,17 +79,18 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new FuseBinaryIntoFollowingAffine); transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); + transformations->Add(new MoveBinaryOperatorBeforeReshape); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); transformations->Add(new ResolveConstantGather); + transformations->Add(new ResolveConstantPack); transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); transformations->Add(new ResolveConstantSlice); - transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); @@ -104,17 +105,19 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new IdentifyRelu1); transformations->Add(new IdentifyPRelu); transformations->Add(new RemoveTrivialBinaryOperator); - transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolveFakeQuantArgsFromVars); + transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant); transformations->Add(new ResolveSpaceToBatchNDAttributes); transformations->Add(new ResolveBatchToSpaceNDAttributes); transformations->Add(new ResolvePadAttributes); transformations->Add(new ResolvePadV2Attributes); transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); - transformations->Add(new ResolveMeanAttributes); + transformations->Add(new ResolveReduceAttributes); transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new UnpartitionEmbeddingLookup); + transformations->Add(new ResolveGatherAttributes); } bool SupportsQuantization(FileFormat format) { @@ -272,13 +275,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) { transformations.Add(new toco::MergeLstmCellInputs); } } - if (toco_flags.quantize_weights()) { - transformations.Add(new QuantizeWeights); - } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); + if (toco_flags.quantize_weights()) { + // Run the quantize weights transformation after batchnorms have been + // folded into the weights. + RunGraphTransformations(model, "quantize weights transformation", + {new QuantizeWeights}); + } if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { RunGraphTransformations(model, |