diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc index 0dfdc40e4c..68c6fb65c5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data, // For example, an Add operator is trivial if // one of its operands is constant 0, a Mul operator is trivial // if one of its operands is constant 1, etc. -bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can resolve here. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { const auto& input_array_1 = model->GetArray(binary_op->inputs[1]); if (!input_array_0.has_shape() || !input_array_1.has_shape()) { // Both input shapes must be known. - return false; + return ::tensorflow::Status::OK(); } if (input_array_0.shape().dimensions_count() == input_array_1.shape().dimensions_count() && @@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { "(lhs %s, rhs %s)", LogName(*binary_op), ShapeToString(input_array_0.shape()), ShapeToString(input_array_1.shape())); - return false; + return ::tensorflow::Status::OK(); } // Now check if the constant operand makes this binary @@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { model->GetArray(binary_op->inputs[index_of_constant_input]); // For now, we only handle floats here. if (constant_input_array.data_type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } const auto& constant_input_float_data = constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; @@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { } if (!is_trivial) { - return false; + return ::tensorflow::Status::OK(); } // Now we know that this node is trivial, so we can remove it. AddMessageF("Removing trivial %s", LogName(*binary_op)); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco |