diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index b689be0792..b6d712ca44 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -21,10 +21,13 @@ limitations under the License. namespace toco { -bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto tile_it = model->operators.begin() + op_index; if (tile_it->get()->type != OperatorType::kTile) { - return false; + return ::tensorflow::Status::OK(); } auto* tile_op = static_cast<TransposeOperator*>(tile_it->get()); @@ -34,13 +37,13 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { if (!input_array.has_shape() || !multiples_array.has_shape() || !output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. - return false; + return ::tensorflow::Status::OK(); } // Note: We can assume we have error checked inputs in PropagateFixedSizes. if (!multiples_array.buffer) { // Yield until the multiples is constant. - return false; + return ::tensorflow::Status::OK(); } std::vector<int32> const& multiples = multiples_array.GetBuffer<ArrayDataType::kInt32>().data; @@ -59,7 +62,7 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { // The tile is non-trivial. Good luck. AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", LogName(*tile_op)); - return false; + return ::tensorflow::Status::OK(); } // The tile is like a concat. @@ -88,7 +91,8 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { CHECK_EQ(tile_it->get(), tile_op); model->operators.erase(tile_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |