aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc24
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index d916ae0ddf..0c60fdfeb3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
} // namespace
// Resolves the concatenation operator if all its inputs are constant arrays.
-bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantConcatenation::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto concat_it = model->operators.begin() + op_index;
const auto* concat_base_op = concat_it->get();
if (concat_base_op->type != OperatorType::kConcatenation) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* concat_op =
static_cast<const ConcatenationOperator*>(concat_base_op);
@@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// We also make sure the shapes of the input arrays are known and they are
// all discardable.
const Operator* input_op = GetOpWithOutput(*model, input_name);
- if (input_op) return false;
- if (!IsConstantParameterArray(*model, input_name)) return false;
- if (!model->GetArray(input_name).has_shape()) return false;
- if (model->GetArray(input_name).quantization_params) return false;
- if (!IsDiscardableArray(*model, input_name)) return false;
+ if (input_op) return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, input_name))
+ return ::tensorflow::Status::OK();
+ if (!model->GetArray(input_name).has_shape())
+ return ::tensorflow::Status::OK();
+ if (model->GetArray(input_name).quantization_params)
+ return ::tensorflow::Status::OK();
+ if (!IsDiscardableArray(*model, input_name))
+ return ::tensorflow::Status::OK();
}
const int concatenation_axis = concat_op->axis;
@@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove concatenate operator.
model->operators.erase(concat_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco