Diffstat (limited to 'tensorflow/contrib/lite/toco/toco_tooling.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc  26
1 file changed, 16 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 76e9a27aef..96c5ebd64f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -130,20 +130,26 @@ bool SupportsPreallocatedWorkspace(FileFormat format) {
}
bool IsRealValued(toco::ArrayDataType type) {
+ // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
+ // for quantized real-number values, and no other integer type is ever used
+ // for that. This is dirty, should be resolved as part of a more general push
+ // to more explicitly distinguish between true-integers and
+ // integers used as quantized values representing real numbers.
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
- type == toco::ArrayDataType::kUint8);
+ type == toco::ArrayDataType::kUint8 ||
+ type == toco::ArrayDataType::kInt16);
}
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
ArrayDataType type;
- if (toco_flags.has_inference_input_type()) {
+ if (!SupportsQuantization(output_format)) {
+ // Data type is implicitly float for non-quantized formats
+ type = ArrayDataType::kFloat;
+ } else if (toco_flags.has_inference_input_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
} else if (toco_flags.has_inference_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
- } else if (!SupportsQuantization(output_format)) {
- // Data type is implicitly float for non-quantized formats
- type = ArrayDataType::kFloat;
} else {
// Nothing to do. Data types stay as-is.
return;
@@ -198,11 +204,6 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
}
void Transform(const TocoFlags& toco_flags, Model* model) {
- // Clean up after import.
- SetFinalDataTypeOnInputs(toco_flags, model);
- UseArraysExtraInfo(model);
- FinishBuildingRNNStates(model);
-
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
@@ -215,6 +216,11 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
<< "Quantized inference is not allowed with float inputs.";
}
+ // Clean up after import.
+ SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model, quantize_output);
+ FinishBuildingRNNStates(model);
+
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
// this will stop BatchNorm fusing if the output node is in between a conv