aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index 95bc7f7d4b..06de9b1cd8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -102,18 +102,19 @@ std::vector<int32> ReshapeToTranspose(const Model& model,
// to be merged if the reshape does not affect memory ordering and does not
// affects the number of dimensions. This only occurs when only unary dimensions
// are shifting position.
-bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
it->get(), OperatorType::kReshape);
if (reshape_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
const string intermediate_name = reshape_op->inputs[0];
@@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Guarantee the input is only consume by the reshape.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check for the parent operator.
const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
if (transpose_it == model->operators.end()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the parent operator and guarantee it is a transpose.
@@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
transpose_it->get(), OperatorType::kTranspose);
if (transpose_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
false /*allow_extra_unary_dimensions*/)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that the intermediate is not an output array.
@@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
"Cannot fuse %s and %s as it would invalidate the transpose "
"output array.",
LogName(*transpose_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
@@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Remove the reshape as passthrough operation.
if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Update transpose_op's constant buffer to contain the new permutation.
@@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// transpose_ops's shape will likely has changed.
model->GetArray(transpose_op->outputs[0]).clear_shape();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco