diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/toco_tooling.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 76e9a27aef..96c5ebd64f 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -130,20 +130,26 @@ bool SupportsPreallocatedWorkspace(FileFormat format) { } bool IsRealValued(toco::ArrayDataType type) { + // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used + // for quantized real-number values, and no other integer type is ever used + // for that. This is dirty, should be resolved as part of a more general push + // to more explicitly distinguish between true-integers and + // integers used as quantized values representing real numbers. return static_cast<bool>(type == toco::ArrayDataType::kFloat || - type == toco::ArrayDataType::kUint8); + type == toco::ArrayDataType::kUint8 || + type == toco::ArrayDataType::kInt16); } void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { const FileFormat output_format = toco_flags.output_format(); ArrayDataType type; - if (toco_flags.has_inference_input_type()) { + if (!SupportsQuantization(output_format)) { + // Data type is implicitly float for non-quantized formats + type = ArrayDataType::kFloat; + } else if (toco_flags.has_inference_input_type()) { type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type()); } else if (toco_flags.has_inference_type()) { type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type()); - } else if (!SupportsQuantization(output_format)) { - // Data type is implicitly float for non-quantized formats - type = ArrayDataType::kFloat; } else { // Nothing to do. Data types stay as-is. return; @@ -198,11 +204,6 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags, } void Transform(const TocoFlags& toco_flags, Model* model) { - // Clean up after import. - SetFinalDataTypeOnInputs(toco_flags, model); - UseArraysExtraInfo(model); - FinishBuildingRNNStates(model); - const FileFormat output_format = toco_flags.output_format(); const IODataType inference_type = toco_flags.inference_type(); @@ -215,6 +216,11 @@ void Transform(const TocoFlags& toco_flags, Model* model) { << "Quantized inference is not allowed with float inputs."; } + // Clean up after import. + SetFinalDataTypeOnInputs(toco_flags, model); + UseArraysExtraInfo(model, quantize_output); + FinishBuildingRNNStates(model); + // Remove unused ops before performing any other optimizations. This is to // stop optimizations from crossing the input/output boundaries. For example // this will stop BatchNorm fusing if the output node is in between a conv |