aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc24
1 files changed, 22 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 76c6be00d4..b324631579 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -274,8 +274,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
return false;
}
- const auto& weights = model->GetArray(preceding_op->inputs[1]);
- const auto& bias = model->GetArray(preceding_op->inputs[2]);
+ const auto& weights_name = preceding_op->inputs[1];
+ const auto& bias_name = preceding_op->inputs[2];
+ const auto& weights = model->GetArray(weights_name);
+ const auto& bias = model->GetArray(bias_name);
+ const int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name);
+ const int count_ops_consuming_weights =
+ CountOpsWithInput(*model, weights_name);
+
if (binary_op->type == OperatorType::kAdd ||
binary_op->type == OperatorType::kSub) {
if (!bias.buffer) {
@@ -285,6 +291,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LogName(*binary_op), LogName(*preceding_op));
return false;
}
+ if (count_ops_consuming_bias > 1) {
+ AddMessageF(
+ "Not fusing %s because the bias of the preceding %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
} else {
if (!weights.buffer || !bias.buffer) {
AddMessageF(
@@ -293,6 +306,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LogName(*binary_op), LogName(*preceding_op));
return false;
}
+ if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
+ AddMessageF(
+ "Not fusing %s because the weights or bias of the preceding %s is "
+ "consumed by another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
}
int count_ops_consuming_output =