aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc117
1 files changed, 69 insertions, 48 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857cfc2..aac77eb39e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@ namespace toco {
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO)
+ << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco