diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 76c6be00d4..b324631579 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -274,8 +274,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { return false; } - const auto& weights = model->GetArray(preceding_op->inputs[1]); - const auto& bias = model->GetArray(preceding_op->inputs[2]); + const auto& weights_name = preceding_op->inputs[1]; + const auto& bias_name = preceding_op->inputs[2]; + const auto& weights = model->GetArray(weights_name); + const auto& bias = model->GetArray(bias_name); + const int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name); + const int count_ops_consuming_weights = + CountOpsWithInput(*model, weights_name); + if (binary_op->type == OperatorType::kAdd || binary_op->type == OperatorType::kSub) { if (!bias.buffer) { @@ -285,6 +291,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + if (count_ops_consuming_bias > 1) { + AddMessageF( + "Not fusing %s because the bias of the preceding %s is consumed by " + "another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } else { if (!weights.buffer || !bias.buffer) { AddMessageF( @@ -293,6 +306,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) { + AddMessageF( + "Not fusing %s because the weights or bias of the preceding %s is " + "consumed by another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } int count_ops_consuming_output = |