aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index 5a36a90b38..e5a96d4335 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
} // namespace
-bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto transpose_it = model->operators.begin() + op_index;
if (transpose_it->get()->type != OperatorType::kTranspose) {
- return false;
+ return ::tensorflow::Status::OK();
}
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
@@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!input_array.has_shape() || !output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
- return false;
+ return ::tensorflow::Status::OK();
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
// Check that the permutation has propogated.
std::vector<int> const& perm = transpose_op->perm;
if (perm.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transpose is trivial if non-unitary dimensions remain in the same
@@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
std::vector<int> const& output_dims = output_array.shape().dims();
if (TransposeAffectsMemoryOrder(perm, input_dims)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transpose is trivial. Replace it with a Reshape op.
@@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(transpose_it->get(), transpose_op);
model->operators.erase(transpose_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco