diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index d916ae0ddf..0c60fdfeb3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation, } // namespace // Resolves the concatenation operator if all its inputs are constant arrays. -bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto concat_it = model->operators.begin() + op_index; const auto* concat_base_op = concat_it->get(); if (concat_base_op->type != OperatorType::kConcatenation) { - return false; + return ::tensorflow::Status::OK(); } const auto* concat_op = static_cast<const ConcatenationOperator*>(concat_base_op); @@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // We also make sure the shapes of the input arrays are known and they are // all discardable. const Operator* input_op = GetOpWithOutput(*model, input_name); - if (input_op) return false; - if (!IsConstantParameterArray(*model, input_name)) return false; - if (!model->GetArray(input_name).has_shape()) return false; - if (model->GetArray(input_name).quantization_params) return false; - if (!IsDiscardableArray(*model, input_name)) return false; + if (input_op) return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, input_name)) + return ::tensorflow::Status::OK(); + if (!model->GetArray(input_name).has_shape()) + return ::tensorflow::Status::OK(); + if (model->GetArray(input_name).quantization_params) + return ::tensorflow::Status::OK(); + if (!IsDiscardableArray(*model, input_name)) + return ::tensorflow::Status::OK(); } const int concatenation_axis = concat_op->axis; @@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove concatenate operator. model->operators.erase(concat_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |