aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc24
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index e88839be5d..a151012891 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -24,29 +24,32 @@ limitations under the License.
namespace toco {
-bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto conv_it = model->operators.begin() + op_index;
if (conv_it->get()->type != OperatorType::kConv) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
if (conv_op->stride_width != conv_op->stride_height) {
- return false;
+ return ::tensorflow::Status::OK();
}
if ((conv_op->dilation_width_factor != 1) ||
(conv_op->dilation_height_factor != 1)) {
// Depthwise conv does not support dilation
- return false;
+ return ::tensorflow::Status::OK();
}
auto& input_array = model->GetArray(conv_op->inputs[0]);
if (!input_array.has_shape()) {
// Shapes not propagated yet
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dims(3) != 1) {
// Not a pure convolution: Conv does accumulation across the depth
// dimension.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights_name = conv_op->inputs[1];
@@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
"Not changing %s to DepthwiseConv because the weights is consumed by "
"another op.",
LogName(*conv_op));
- return false;
+ return ::tensorflow::Status::OK();
}
auto& weights_array = model->GetArray(weights_name);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
- return false;
+ return ::tensorflow::Status::OK();
}
if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
// At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
AddMessageF(
@@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
}
*weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
weights_buffer.data = depthwise_conv_weights_data;
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco