diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index e88839be5d..a151012891 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -24,29 +24,32 @@ limitations under the License. namespace toco { -bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto conv_it = model->operators.begin() + op_index; if (conv_it->get()->type != OperatorType::kConv) { - return false; + return ::tensorflow::Status::OK(); } const auto* conv_op = static_cast<ConvOperator*>(conv_it->get()); if (conv_op->stride_width != conv_op->stride_height) { - return false; + return ::tensorflow::Status::OK(); } if ((conv_op->dilation_width_factor != 1) || (conv_op->dilation_height_factor != 1)) { // Depthwise conv does not support dilation - return false; + return ::tensorflow::Status::OK(); } auto& input_array = model->GetArray(conv_op->inputs[0]); if (!input_array.has_shape()) { // Shapes not propagated yet - return false; + return ::tensorflow::Status::OK(); } if (input_array.shape().dims(3) != 1) { // Not a pure convolution: Conv does accumulation across the depth // dimension. - return false; + return ::tensorflow::Status::OK(); } const auto& weights_name = conv_op->inputs[1]; @@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { "Not changing %s to DepthwiseConv because the weights is consumed by " "another op.", LogName(*conv_op)); - return false; + return ::tensorflow::Status::OK(); } auto& weights_array = model->GetArray(weights_name); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. - return false; + return ::tensorflow::Status::OK(); } if (weights_array.data_type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( @@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { } *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth}; weights_buffer.data = depthwise_conv_weights_data; - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |