diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 8a0e3e8995..a1756a8207 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -19,29 +19,32 @@ limitations under the License. namespace toco { -bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been resolved - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the input array's shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } // Compute the output @@ -65,7 +68,8 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { } model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |