aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-17 10:59:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 11:02:57 -0700
commit532e7210672c670529f1c27404c04eb2a6a1035f (patch)
treeb53cb0f7e1a59b5b6109612929a1f9c3e4c804bd
parentfe48d9189c74e05be21e928a445a006bcd4a13f8 (diff)
Propagate narrow_range flags in suitable constant-propagation
transformations. PiperOrigin-RevId: 209174233
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc8
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc10
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h3
6 files changed, 17 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index d395d7a6a0..f5f2f77460 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -117,6 +117,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
&quantized_max);
if (fakequant_op->narrow_range) {
quantized_min++;
+ output_array.narrow_range = true;
}
// It is important for matching accuracy between TF training and TFLite
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index 41562ab393..a6f665b5f0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -100,13 +100,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolving constant reshape of %s", LogName(*op));
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
// Erase input arrays if no longer used.
for (const auto& input : op->inputs) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
index 0b0d070714..5cfa1a5582 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -128,15 +128,7 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
multiples_array.data_type == ArrayDataType::kInt64)
<< "Only int32/int64 indices are supported";
- // Copy min/max info if present. The ranges of the selected values may be
- // a subset of the original range but we want to ensure the quantization
- // params stay the same.
- if (input_array.minmax) {
- const auto& input_minmax = input_array.GetMinMax();
- auto& output_minmax = output_array.GetOrCreateMinMax();
- output_minmax.min = input_minmax.min;
- output_minmax.max = input_minmax.max;
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
CHECK(!output_array.buffer);
switch (output_array.data_type) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
index 1fd20314b1..fe15dfa06f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
@@ -128,13 +128,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
}
const Array& input_array = model->GetArray(op->inputs[0]);
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2ad2719811..3a4542f522 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -2278,4 +2278,14 @@ void UndoWeightsShuffling(Model* model) {
}
}
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
+ if (src.minmax) {
+ dst->GetOrCreateMinMax() = src.GetMinMax();
+ }
+ if (src.quantization_params) {
+ dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
+ }
+ dst->narrow_range = src.narrow_range;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index b99e6111fe..bdeb203024 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -348,6 +348,9 @@ tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
// so that the rest of toco doesn't need to know about shuffled weights.
void UndoWeightsShuffling(Model* model);
+// Copies minmax, quantization_params, and narrow_range.
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_