diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc | 117 |
1 files changed, 69 insertions, 48 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index d49857cfc2..aac77eb39e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -53,50 +53,11 @@ namespace toco { // thrown in just for the extra headache. Padding adapts non-conforming input // sizes, and can be discarded. The bias is necessary, so is kept. -bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { - const auto it = model->operators.begin() + op_index; - auto* stb_op = it->get(); - - // 1. IDENTIFY OPERATORS - // *************************************************************************** - // SpaceToBatch Op. - if (stb_op->type != OperatorType::kSpaceToBatchND) { - return false; - } - if (stb_op->inputs.size() != 3) { - return false; - } - CHECK_EQ(stb_op->outputs.size(), 1); - // Extract the dilation factor from Input[1] of SpaceToBatch - // TODO(mjmatthews): Support 2D dilation factors. - const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); - if (!block_shape_array.buffer) { - return false; - } - CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); - int dilation_factor = - block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0]; - - // Expand Op - auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); - if (!post_stb_op) { - return false; - } - bool has_expand_op = false; - if (post_stb_op->type == OperatorType::kExpandDims) { - has_expand_op = true; - CHECK_EQ(post_stb_op->inputs.size(), 2); - CHECK_EQ(post_stb_op->outputs.size(), 1); - } - - // Conv Op - const string& input_of_conv_op = - has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; - auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); - if (conv_base_op->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast<ConvOperator*>(conv_base_op); +template <typename T> +bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, + Operator* post_stb_op, bool has_expand_op, + int dilation_factor) { + auto* conv_op = static_cast<T*>(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. return false; @@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { CHECK_EQ(bias_add_op->inputs.size(), 2); CHECK_EQ(bias_add_op->outputs.size(), 1); - LOG(INFO) << "Identified sub-network emulating dilated convolution."; - // 2. RE-WIRE OPERATORS // *************************************************************************** // Re-use the existing Conv2D op. @@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { DeleteArrayIfUnused(stb_op_inputs[1], model); DeleteArrayIfUnused(stb_op_inputs[2], model); - LOG(INFO) << "Replaced with Dilated Conv2D op outputting \"" - << conv_op->outputs[0] << "\"."; return true; } +bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* stb_op = it->get(); + + // 1. IDENTIFY OPERATORS + // *************************************************************************** + // SpaceToBatch Op. + if (stb_op->type != OperatorType::kSpaceToBatchND) { + return false; + } + if (stb_op->inputs.size() != 3) { + return false; + } + CHECK_EQ(stb_op->outputs.size(), 1); + // Extract the dilation factor from Input[1] of SpaceToBatch + // TODO(mjmatthews): Support 2D dilation factors. + const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); + if (!block_shape_array.buffer) { + return false; + } + CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); + int dilation_factor = + block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0]; + + // Expand Op + auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); + if (!post_stb_op) { + return false; + } + bool has_expand_op = false; + if (post_stb_op->type == OperatorType::kExpandDims) { + has_expand_op = true; + CHECK_EQ(post_stb_op->inputs.size(), 2); + CHECK_EQ(post_stb_op->outputs.size(), 1); + } + + // Conv Op + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + bool changed = false; + if (conv_base_op->type == OperatorType::kConv) { + changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op, + post_stb_op, has_expand_op, + dilation_factor); + if (changed) { + LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \"" + << conv_base_op->outputs[0] << "\"."; + } + } else if (identify_depthwise_conv_ && + conv_base_op->type == OperatorType::kDepthwiseConv) { + changed = ResolveDilatedConv<DepthwiseConvOperator>( + model, conv_base_op, stb_op, post_stb_op, has_expand_op, + dilation_factor); + if (changed) { + LOG(INFO) + << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \"" + << conv_base_op->outputs[0] << "\"."; + } + } + + return changed; +} + } // namespace toco |