aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc32
1 files changed, 18 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index dcbbead517..0de22b8ff4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
} // namespace
-bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s into the following affine op, because we only know "
"how to do so when the constant operand is a scalar",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
FusedActivationFunctionType::kNone) {
AddMessageF("Not fusing %s because it has a fused activation function",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
@@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because it is not consumed by exactly one other op",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (following_op->type != OperatorType::kConv &&
@@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (following_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the following %s does not have a bias vector",
LogName(*following_op), LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights = model->GetArray(following_op->inputs[1]);
@@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Try to fuse the binary params into the following op's params
@@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
if (following_op->type == OperatorType::kDepthwiseConv) {
@@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
@@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco