aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc22
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
index 0dfdc40e4c..68c6fb65c5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
// For example, an Add operator is trivial if
// one of its operands is constant 0, a Mul operator is trivial
// if one of its operands is constant 1, etc.
-bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input_array_1 = model->GetArray(binary_op->inputs[1]);
if (!input_array_0.has_shape() || !input_array_1.has_shape()) {
// Both input shapes must be known.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array_0.shape().dimensions_count() ==
input_array_1.shape().dimensions_count() &&
@@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
"(lhs %s, rhs %s)",
LogName(*binary_op), ShapeToString(input_array_0.shape()),
ShapeToString(input_array_1.shape()));
- return false;
+ return ::tensorflow::Status::OK();
}
// Now check if the constant operand makes this binary
@@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& constant_input_float_data =
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
@@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
}
if (!is_trivial) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Now we know that this node is trivial, so we can remove it.
AddMessageF("Removing trivial %s", LogName(*binary_op));
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco