aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc18
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
index 81cedb5dad..a0bd1ed4a4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
@@ -30,10 +30,13 @@ namespace toco {
// means that the data layout will never change with this op, just the shape.
// By converting these to reshapes once we have run shape propagation we allow
// standard reshape optimization transforms to do their magic.
-bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto squeeze_it = model->operators.begin() + op_index;
if (squeeze_it->get()->type != OperatorType::kSqueeze) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto squeeze_op = static_cast<SqueezeOperator*>(squeeze_it->get());
CHECK_EQ(squeeze_op->inputs.size(), 1);
@@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(squeeze_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!model->HasArray(squeeze_op->outputs[0]) ||
!model->GetArray(squeeze_op->outputs[0]).has_shape()) {
// Yield until shape propagation has set the output shape for us.
- return false;
+ return ::tensorflow::Status::OK();
}
// We use the output shape that has been calculated by shape propagation.
@@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
// Empty shapes will not work as empty data arrays.
if (output_shape.dimensions_count() == 0) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* reshape_op = new TensorFlowReshapeOperator;
@@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(squeeze_it->get(), squeeze_op);
model->operators.erase(squeeze_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco