aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 8543ba4742..99a54a300b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -197,6 +197,7 @@ const char* OperatorTypeName(OperatorType type) {
case OperatorType::k##c: \
return #c;
HANDLE_OPERATORTYPENAME_CASE(Add)
+ HANDLE_OPERATORTYPENAME_CASE(AddN)
HANDLE_OPERATORTYPENAME_CASE(AveragePool)
HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
HANDLE_OPERATORTYPENAME_CASE(Conv)
@@ -1396,6 +1397,16 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
total += RequiredBufferSizeForShape(output_array.shape());
break;
}
+ case OperatorType::kAddN: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // AddN cost is roughly the same cost as N-1 Adds.
+ const int num_adds = op->inputs.size() - 1;
+ total += num_adds * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
case OperatorType::kLogistic:
case OperatorType::kSoftmax:
case OperatorType::kTanh: {