aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
index 874d8def57..4848867b9a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) {
// Finds an operation that looks like a broadcast (concat of the same sources
// along the last dimension) and drops it by relying on the ability of certain
// binary ops to perform an implicit broadcast.
-bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
if (binary_op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
// NOTE: either of these ops may be nullptr if the input array is constant.
@@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
if (!is_op_0_broadcast && !is_op_1_broadcast) {
// Neither input is a broadcast-looking thing.
AddMessageF("Neither input looks broadcasty");
- return false;
+ return ::tensorflow::Status::OK();
} else if (is_op_0_broadcast && is_op_1_broadcast) {
AddMessageF(
"Unable to fuse broadcast into %s as both inputs (%s, %s) are "
"broadcasts",
LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
op[1] ? LogName(*op[1]) : "(?)");
- return false;
+ return ::tensorflow::Status::OK();
}
int broadcast_index = is_op_0_broadcast ? 0 : 1;
@@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
// We leave the broadcast op in; it'll get cleaned up if it's not used later.
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco