Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc  |  36
1 file changed, 20 insertions(+), 16 deletions(-)
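
The patch below converts this graph transformation's Run method from returning bool ("did I modify the graph?") to returning ::tensorflow::Status, with graph modification reported through a bool* out-parameter instead. A minimal sketch of the pattern, using stand-in types rather than the real toco/TensorFlow classes:

    #include <cstddef>

    // Stand-ins for illustration only; the real Model and Status come
    // from toco and tensorflow/core, respectively.
    struct Model {};
    struct Status {
      static Status OK() { return Status(); }
      bool ok() const { return true; }
    };

    // Old convention:  bool Run(Model*, std::size_t);
    // New convention:  errors travel through the returned Status, and
    // "graph changed" travels through the out-parameter, which is
    // cleared up front so that every early "not fusing" return is a
    // clean no-op rather than a silent failure.
    Status Run(Model* model, std::size_t op_index, bool* modified) {
      *modified = false;
      // ... bail-out checks, each ending in:  return Status::OK();
      // ... on a successful fuse:
      *modified = true;
      return Status::OK();
    }
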
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index b324631579..b8da756d85 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
}
} // namespace
-bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model,
+                                                        std::size_t op_index,
+                                                        bool* modified) {
+  *modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (!preceding_op) {
AddMessageF("Not fusing %s because it is not the output of another op",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
for (const string& output_array : model->flags.output_arrays()) {
if (preceding_op->outputs[0] == output_array) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (preceding_op->fused_activation_function !=
@@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a fused activation "
"function",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (preceding_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the preceding %s does not have a bias vector",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights_name = preceding_op->inputs[1];
@@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a non-constant bias "
"array",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the bias of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
} else {
if (!weights.buffer || !bias.buffer) {
@@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the weights or bias of the preceding %s is "
"consumed by another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the output of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
@@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
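
For context, a caller of the rewritten transformation would handle the result roughly as follows. This is a hypothetical driver loop with illustrative names (RunTransformation, RunOverAllOps); the actual toco pass manager differs:

    #include <cstddef>

    struct Model {};
    struct Status {
      static Status OK() { return Status(); }
      bool ok() const { return true; }
    };

    // Stub standing in for the transformation above; it matches the
    // new signature introduced by this patch.
    Status RunTransformation(Model* model, std::size_t op_index,
                             bool* modified) {
      *modified = false;
      return Status::OK();
    }

    Status RunOverAllOps(Model* model, std::size_t num_ops,
                         bool* any_modified) {
      *any_modified = false;
      for (std::size_t i = 0; i < num_ops; ++i) {
        bool modified = false;
        Status status = RunTransformation(model, i, &modified);
        if (!status.ok()) return status;     // errors propagate explicitly
        if (modified) *any_modified = true;  // change-tracking is separate
      }
      return Status::OK();
    }
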