diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index 5a36a90b38..e5a96d4335 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm, } // namespace -bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto transpose_it = model->operators.begin() + op_index; if (transpose_it->get()->type != OperatorType::kTranspose) { - return false; + return ::tensorflow::Status::OK(); } TransposeOperator* transpose_op = static_cast<TransposeOperator*>(transpose_it->get()); @@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { const auto& output_array = model->GetArray(transpose_op->outputs[0]); if (!input_array.has_shape() || !output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. - return false; + return ::tensorflow::Status::OK(); } // Note: We can assume we have error checked inputs in PropagateFixedSizes. // Check that the permutation has propogated. std::vector<int> const& perm = transpose_op->perm; if (perm.empty()) { - return false; + return ::tensorflow::Status::OK(); } // This transpose is trivial if non-unitary dimensions remain in the same @@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { std::vector<int> const& output_dims = output_array.shape().dims(); if (TransposeAffectsMemoryOrder(perm, input_dims)) { - return false; + return ::tensorflow::Status::OK(); } // This transpose is trivial. Replace it with a Reshape op. @@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(transpose_it->get(), transpose_op); model->operators.erase(transpose_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |