diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc index 81cedb5dad..a0bd1ed4a4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -30,10 +30,13 @@ namespace toco { // means that the data layout will never change with this op, just the shape. // By converting these to reshapes once we have run shape propagation we allow // standard reshape optimization transforms to do their magic. -bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto squeeze_it = model->operators.begin() + op_index; if (squeeze_it->get()->type != OperatorType::kSqueeze) { - return false; + return ::tensorflow::Status::OK(); } auto squeeze_op = static_cast<SqueezeOperator*>(squeeze_it->get()); CHECK_EQ(squeeze_op->inputs.size(), 1); @@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { const auto& input_array = model->GetArray(squeeze_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } if (input_array.shape().dimensions_count() == 0) { // Input array cannot be 0-D. - return false; + return ::tensorflow::Status::OK(); } if (!model->HasArray(squeeze_op->outputs[0]) || !model->GetArray(squeeze_op->outputs[0]).has_shape()) { // Yield until shape propagation has set the output shape for us. - return false; + return ::tensorflow::Status::OK(); } // We use the output shape that has been calculated by shape propagation. @@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { // Empty shapes will not work as empty data arrays. if (output_shape.dimensions_count() == 0) { - return false; + return ::tensorflow::Status::OK(); } auto* reshape_op = new TensorFlowReshapeOperator; @@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(squeeze_it->get(), squeeze_op); model->operators.erase(squeeze_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |