about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-17 11:53:29 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-17 11:55:26 -0700
commit: d7b6cb66c0fc346cf55020042931c07208713c60 (patch)
tree: 9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
parent: 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff)
Fixes and cleanups to support more complex quantized models; adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc')
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc | 25
1 file changed, 24 insertions, 1 deletion
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 625d90205a..efb7bb2184 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,9 +46,29 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ CHECK(input_array.data_type == ArrayDataType::kFloat);
+
+ // Determine the final data type in the same way as PropagateFakeQuantNumBits.
+ ArrayDataType quantized_data_type = input_array.final_data_type;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
+ return false;
+ }
+
+ AddMessageF("Resolving constant %s", LogName(*fakequant_op));
+
auto& output_array = model->GetArray(fakequant_op->outputs[0]);
CHECK(input_array.data_type == ArrayDataType::kFloat);
output_array.data_type = ArrayDataType::kFloat;
+
+ // We'll set the final data type to what the fake quant indicates we should
+ // have (and would have been set if this stayed around until
+ // PropagateFakeQuantNumBits).
+ if (propagate_fake_quant_num_bits()) {
+ output_array.final_data_type = quantized_data_type;
+ }
+
CHECK(!output_array.buffer);
const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
@@ -66,7 +87,9 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
output_buffer.data[i] = dst_val;
}
- if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+
+ if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
+ CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);