diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 310a88484c..8a945ac435 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -25,10 +25,13 @@ limitations under the License. namespace toco { -bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto expand_it = model->operators.begin() + op_index; if (expand_it->get()->type != OperatorType::kExpandDims) { - return false; + return ::tensorflow::Status::OK(); } ExpandDimsOperator* expand_op = static_cast<ExpandDimsOperator*>(expand_it->get()); @@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { const auto& input_array = model->GetArray(expand_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { // Yield until input axis array shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1); if (!axis_array.buffer) { // Yield until the input axis array is constant - return false; + return ::tensorflow::Status::OK(); } int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0]; std::vector<int> reshape_dims(input_array.shape().dims()); @@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(expand_it->get(), expand_op); model->operators.erase(expand_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |