diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-10-09 11:38:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 11:48:46 -0700 |
commit | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch) | |
tree | d2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc | |
parent | 931353c5f79c2d419afb3a5ecac59184c5558351 (diff) |
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index d916ae0ddf..0c60fdfeb3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation, } // namespace // Resolves the concatenation operator if all its inputs are constant arrays. -bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto concat_it = model->operators.begin() + op_index; const auto* concat_base_op = concat_it->get(); if (concat_base_op->type != OperatorType::kConcatenation) { - return false; + return ::tensorflow::Status::OK(); } const auto* concat_op = static_cast<const ConcatenationOperator*>(concat_base_op); @@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // We also make sure the shapes of the input arrays are known and they are // all discardable. const Operator* input_op = GetOpWithOutput(*model, input_name); - if (input_op) return false; - if (!IsConstantParameterArray(*model, input_name)) return false; - if (!model->GetArray(input_name).has_shape()) return false; - if (model->GetArray(input_name).quantization_params) return false; - if (!IsDiscardableArray(*model, input_name)) return false; + if (input_op) return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, input_name)) + return ::tensorflow::Status::OK(); + if (!model->GetArray(input_name).has_shape()) + return ::tensorflow::Status::OK(); + if (model->GetArray(input_name).quantization_params) + return ::tensorflow::Status::OK(); + if (!IsDiscardableArray(*model, input_name)) + return ::tensorflow::Status::OK(); } const int concatenation_axis = concat_op->axis; @@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove concatenate operator. model->operators.erase(concat_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |