diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc index 874d8def57..4848867b9a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) { // Finds an operation that looks like a broadcast (concat of the same sources // along the last dimension) and drops it by relying on the ability of certain // binary ops to perform an implicit broadcast. -bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); // Test for binary ops of types that we know how to resolve if (binary_op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } // NOTE: either of these ops may be nullptr if the input array is constant. @@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { if (!is_op_0_broadcast && !is_op_1_broadcast) { // Neither input is a broadcast-looking thing. AddMessageF("Neither input looks broadcasty"); - return false; + return ::tensorflow::Status::OK(); } else if (is_op_0_broadcast && is_op_1_broadcast) { AddMessageF( "Unable to fuse broadcast into %s as both inputs (%s, %s) are " "broadcasts", LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", op[1] ? LogName(*op[1]) : "(?)"); - return false; + return ::tensorflow::Status::OK(); } int broadcast_index = is_op_0_broadcast ? 0 : 1; @@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; // We leave the broadcast op in; it'll get cleaned up if it's not used later. - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |