diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 95bc7f7d4b..06de9b1cd8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -102,18 +102,19 @@ std::vector<int32> ReshapeToTranspose(const Model& model, // to be merged if the reshape does not affect memory ordering and does not // affects the number of dimensions. This only occurs when only unary dimensions // are shifting position. -bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>( it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { - return false; + return ::tensorflow::Status::OK(); } const string intermediate_name = reshape_op->inputs[0]; @@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // Guarantee the input is only consume by the reshape. if (CountOpsWithInput(*model, intermediate_name) != 1) { - return false; + return ::tensorflow::Status::OK(); } // Check for the parent operator. const auto& transpose_it = FindOpWithOutput(*model, intermediate_name); if (transpose_it == model->operators.end()) { - return false; + return ::tensorflow::Status::OK(); } // Find the parent operator and guarantee it is a transpose. @@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, transpose_it->get(), OperatorType::kTranspose); if (transpose_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { - return false; + return ::tensorflow::Status::OK(); } if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, false /*allow_extra_unary_dimensions*/)) { - return false; + return ::tensorflow::Status::OK(); } // Check that the intermediate is not an output array. @@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, "Cannot fuse %s and %s as it would invalidate the transpose " "output array.", LogName(*transpose_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Merging operations %s and %s", LogName(*transpose_op), @@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // Remove the reshape as passthrough operation. if (!RemoveTrivialPassthroughOp(this, model, op_index)) { - return false; + return ::tensorflow::Status::OK(); } // Update transpose_op's constant buffer to contain the new permutation. @@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // transpose_ops's shape will likely has changed. model->GetArray(transpose_op->outputs[0]).clear_shape(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |