diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index b324631579..b8da756d85 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } } // namespace -bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; const auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can fuse into a constant. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { if (index_of_constant_input != 1) { AddMessageF("Not fusing %s because the denominator is not constant", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { if (!preceding_op) { AddMessageF("Not fusing %s because it is not the output of another op", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } for (const string& output_array : model->flags.output_arrays()) { if (preceding_op->outputs[0] == output_array) { - return false; + return ::tensorflow::Status::OK(); } } @@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s is not of one of the supported " "types", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (preceding_op->fused_activation_function != @@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has a fused activation " "function", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (preceding_op->inputs.size() < 3) { AddMessageF( "Not fusing %s because the preceding %s does not have a bias vector", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& weights_name = preceding_op->inputs[1]; @@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has a non-constant bias " "array", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (count_ops_consuming_bias > 1) { AddMessageF( "Not fusing %s because the bias of the preceding %s is consumed by " "another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } } else { if (!weights.buffer || !bias.buffer) { @@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has non-constant weights or " "bias arrays", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) { AddMessageF( "Not fusing %s because the weights or bias of the preceding %s is " "consumed by another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the output of the preceding %s is consumed by " "another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op), @@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |