diff options
6 files changed, 17 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index d395d7a6a0..f5f2f77460 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -117,6 +117,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { &quantized_max); if (fakequant_op->narrow_range) { quantized_min++; + output_array.narrow_range = true; } // It is important for matching accuracy between TF training and TFLite diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index 41562ab393..a6f665b5f0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -100,13 +100,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { AddMessageF("Resolving constant reshape of %s", LogName(*op)); - if (input_array.minmax) { - output_array.GetOrCreateMinMax() = input_array.GetMinMax(); - } - if (input_array.quantization_params) { - output_array.GetOrCreateQuantizationParams() = - input_array.GetQuantizationParams(); - } + CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array); // Erase input arrays if no longer used. for (const auto& input : op->inputs) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc index 0b0d070714..5cfa1a5582 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -128,15 +128,7 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { multiples_array.data_type == ArrayDataType::kInt64) << "Only int32/int64 indices are supported"; - // Copy min/max info if present. The ranges of the selected values may be - // a subset of the original range but we want to ensure the quantization - // params stay the same. - if (input_array.minmax) { - const auto& input_minmax = input_array.GetMinMax(); - auto& output_minmax = output_array.GetOrCreateMinMax(); - output_minmax.min = input_minmax.min; - output_minmax.max = input_minmax.max; - } + CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array); CHECK(!output_array.buffer); switch (output_array.data_type) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc index 1fd20314b1..fe15dfa06f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -128,13 +128,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { } const Array& input_array = model->GetArray(op->inputs[0]); - if (input_array.minmax) { - output_array.GetOrCreateMinMax() = input_array.GetMinMax(); - } - if (input_array.quantization_params) { - output_array.GetOrCreateQuantizationParams() = - input_array.GetQuantizationParams(); - } + CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array); if (op->perm.empty()) { // Yield until perm has been populated by ResolveTransposeAttributes. diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 2ad2719811..3a4542f522 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -2278,4 +2278,14 @@ void UndoWeightsShuffling(Model* model) { } } +void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) { + if (src.minmax) { + dst->GetOrCreateMinMax() = src.GetMinMax(); + } + if (src.quantization_params) { + dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams(); + } + dst->narrow_range = src.narrow_range; +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index b99e6111fe..bdeb203024 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -348,6 +348,9 @@ tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) { // so that the rest of toco doesn't need to know about shuffled weights. void UndoWeightsShuffling(Model* model); +// Copies minmax, quantization_params, and narrow_range. +void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ |