diff options
author | 2018-04-17 11:53:29 -0700 | |
---|---|---|
committer | 2018-04-17 11:55:26 -0700 | |
commit | d7b6cb66c0fc346cf55020042931c07208713c60 (patch) | |
tree | 9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc | |
parent | 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff) |
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc | 25 |
1 file changed, 24 insertions, 1 deletion
```diff
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 625d90205a..efb7bb2184 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -45,9 +46,29 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
   }
 
   const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+  CHECK(input_array.data_type == ArrayDataType::kFloat);
+
+  // Determine the final data type in the same way as PropagateFakeQuantNumBits.
+  ArrayDataType quantized_data_type = input_array.final_data_type;
+  if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+                                           &quantized_data_type)) {
+    AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
+    return false;
+  }
+
+  AddMessageF("Resolving constant %s", LogName(*fakequant_op));
+
   auto& output_array = model->GetArray(fakequant_op->outputs[0]);
   CHECK(input_array.data_type == ArrayDataType::kFloat);
   output_array.data_type = ArrayDataType::kFloat;
+
+  // We'll set the final data type to what the fake quant indicates we should
+  // have (and would have been set if this stayed around until
+  // PropagateFakeQuantNumBits).
+  if (propagate_fake_quant_num_bits()) {
+    output_array.final_data_type = quantized_data_type;
+  }
+
   CHECK(!output_array.buffer);
   const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
   output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
@@ -66,7 +87,9 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
     const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
     output_buffer.data[i] = dst_val;
   }
-  if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+
+  if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
+      CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
     model->EraseArray(fakequant_op->inputs[0]);
   }
   model->operators.erase(fakequant_it);
```