path: root/tensorflow/contrib/lite/tools
author Suharsh Sivakumar <suharshs@google.com> 2018-08-31 14:44:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-31 15:00:21 -0700
commit bd3ef4faf91aab31cbee3cbe586a5e1b277afbb6 (patch)
tree a4e1114bf9029640078a192db72bd7eb6a02c500 /tensorflow/contrib/lite/tools
parent e082d5208e56d3d8f69544781bebf830eae82de7 (diff)
Hybrid operations need either all or none of their tensors quantized.
PiperOrigin-RevId: 211147312
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r-- tensorflow/contrib/lite/tools/optimize/quantize_weights.cc | 10
1 file changed, 10 insertions(+), 0 deletions(-)
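
The patch below threads a skipped_tensor flag through GetQuantizableTensorsFromOperator and, for ops that will be evaluated in hybrid mode, returns an empty list whenever any weight tensor was skipped. A minimal standalone sketch of that all-or-none pattern follows; TensorCandidate, SelectTensorsToQuantize, and the quantizable field are simplified stand-ins for illustration, not the actual TF Lite types used in quantize_weights.cc:

    // Illustrative distillation of the all-or-none rule added in this patch.
    #include <vector>

    struct TensorCandidate {
      int tensor_idx;
      bool quantizable;  // e.g. float32, enough elements, single consumer
    };

    std::vector<TensorCandidate> SelectTensorsToQuantize(
        const std::vector<TensorCandidate>& candidates, bool eval_hybrid) {
      std::vector<TensorCandidate> selected;
      bool skipped_tensor = false;
      for (const TensorCandidate& c : candidates) {
        if (!c.quantizable) {
          skipped_tensor = true;  // remember that at least one tensor was skipped
          continue;
        }
        selected.push_back(c);
      }
      // Hybrid kernels expect either all of their weight inputs quantized or
      // none of them, so bail out entirely if anything was skipped.
      if (eval_hybrid && skipped_tensor) {
        return {};
      }
      return selected;
    }
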
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index 10a66dd351..df8433dd9b 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -168,6 +168,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
bool eval_hybrid = IsHybridEvaluationOp(op, op_code);
+ bool skipped_tensor = false;
std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
for (const int32_t op_input_idx : op_input_indices) {
int32_t tensor_idx = op->inputs[op_input_idx];
@@ -177,6 +178,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
LOG(INFO) << "Skipping quantization of tensor that is shared between "
"multiple multiple operations.";
+ skipped_tensor = true;
continue;
}
@@ -184,6 +186,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
if (tensor->type != TensorType_FLOAT32) {
LOG(INFO) << "Skipping quantization of tensor that is not type float.";
+ skipped_tensor = true;
continue;
}
@@ -191,6 +194,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
if (num_elements < kWeightsMinSize) {
LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
<< kWeightsMinSize << " elements (" << num_elements << ").";
+ skipped_tensor = true;
continue;
}
@@ -203,6 +207,12 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
tensor_infos.push_back(tensor_info);
}
+ // For hybrid operations we either need to quantize all tensors or none. So
+ // if we skipped any tensors, we need to return no quantized tensors.
+ if (eval_hybrid && skipped_tensor) {
+ return {};
+ }
+
return tensor_infos;
}
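
In effect, if any weight input of a hybrid op cannot be quantized (because it is shared, not float32, or too small), the op keeps all of its weights in float rather than mixing quantized and float inputs. Continuing the illustrative sketch above (again, not the real caller in quantize_weights.cc):

    // Illustrative only: a hybrid op with one non-quantizable weight input
    // ends up with no tensors selected for quantization at all.
    std::vector<TensorCandidate> candidates = {
        {/*tensor_idx=*/3, /*quantizable=*/true},
        {/*tensor_idx=*/4, /*quantizable=*/false},  // e.g. shared or too small
    };
    auto selected = SelectTensorsToQuantize(candidates, /*eval_hybrid=*/true);
    // selected.empty() == true: the op keeps float weights and falls back to
    // full float evaluation instead of mixing quantized and float inputs.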