aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
index 75113a2a8c..78779243a9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
@@ -25,27 +25,30 @@ limitations under the License.
namespace toco {
-bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto pack_it = model->operators.begin() + op_index;
if (pack_it->get()->type != OperatorType::kPack) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* pack_op = static_cast<PackOperator*>(pack_it->get());
if (pack_op->inputs.size() > 1) {
// Not trivial.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(pack_op->outputs.size(), 1);
const auto& input_array = model->GetArray(pack_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
// (Unsure if this is TF behavior, but was required to get a test to pass.)
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op));
@@ -75,7 +78,8 @@ bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(pack_it->get(), pack_op);
model->operators.erase(pack_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco