diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-09-19 17:40:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 17:44:12 -0700 |
commit | 0ab89a599bdb9885532785a5e7b6bfe346e09ee3 (patch) | |
tree | 79bd3c43e9d9485f12054ef1a5c719a3b00a027e | |
parent | 5d2047029a77545c97c0fdf74d9c03c92d1dcb88 (diff) |
TOCO transformations updated to support dilated depthwise convolution.
PiperOrigin-RevId: 213729750
6 files changed, 118 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 3a534300ae..3d1eb3978c 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model, strides.mutable_list()->add_i(src_op.stride_height); strides.mutable_list()->add_i(src_op.stride_width); strides.mutable_list()->add_i(1); + // TODO(b/): To return a working TF GraphDef, we should be returning the + // correct SpaceToBatchNd and BatchToSpaceND operation before and after the + // conv since TF doesn't support dilations. + if ((src_op.dilation_width_factor != 1) || + (src_op.dilation_height_factor != 1)) { + auto& dilations = (*dc2d_op->mutable_attr())["dilations"]; + dilations.mutable_list()->add_i(1); + dilations.mutable_list()->add_i(src_op.dilation_height_factor); + dilations.mutable_list()->add_i(src_op.dilation_width_factor); + dilations.mutable_list()->add_i(1); + } string padding; if (src_op.padding.type == PaddingType::kSame) { padding = "SAME"; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index fdd0632451..4d213b3f9c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -133,7 +133,6 @@ DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose) DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) -DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape) DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants) @@ -266,6 +265,17 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { bool has_default_ranges_flag_ = false; }; +class IdentifyDilatedConv : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "IdentifyDilatedConv"; } + bool identify_depthwise_conv() const { return identify_depthwise_conv_; } + void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; } + + private: + bool identify_depthwise_conv_ = true; +}; + #undef DECLARE_GRAPH_TRANSFORMATION } // end namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index d49857cfc2..aac77eb39e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -53,50 +53,11 @@ namespace toco { // thrown in just for the extra headache. Padding adapts non-conforming input // sizes, and can be discarded. The bias is necessary, so is kept. -bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { - const auto it = model->operators.begin() + op_index; - auto* stb_op = it->get(); - - // 1. IDENTIFY OPERATORS - // *************************************************************************** - // SpaceToBatch Op. - if (stb_op->type != OperatorType::kSpaceToBatchND) { - return false; - } - if (stb_op->inputs.size() != 3) { - return false; - } - CHECK_EQ(stb_op->outputs.size(), 1); - // Extract the dilation factor from Input[1] of SpaceToBatch - // TODO(mjmatthews): Support 2D dilation factors. - const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); - if (!block_shape_array.buffer) { - return false; - } - CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); - int dilation_factor = - block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0]; - - // Expand Op - auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); - if (!post_stb_op) { - return false; - } - bool has_expand_op = false; - if (post_stb_op->type == OperatorType::kExpandDims) { - has_expand_op = true; - CHECK_EQ(post_stb_op->inputs.size(), 2); - CHECK_EQ(post_stb_op->outputs.size(), 1); - } - - // Conv Op - const string& input_of_conv_op = - has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; - auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); - if (conv_base_op->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast<ConvOperator*>(conv_base_op); +template <typename T> +bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, + Operator* post_stb_op, bool has_expand_op, + int dilation_factor) { + auto* conv_op = static_cast<T*>(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. return false; @@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { CHECK_EQ(bias_add_op->inputs.size(), 2); CHECK_EQ(bias_add_op->outputs.size(), 1); - LOG(INFO) << "Identified sub-network emulating dilated convolution."; - // 2. RE-WIRE OPERATORS // *************************************************************************** // Re-use the existing Conv2D op. @@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { DeleteArrayIfUnused(stb_op_inputs[1], model); DeleteArrayIfUnused(stb_op_inputs[2], model); - LOG(INFO) << "Replaced with Dilated Conv2D op outputting \"" - << conv_op->outputs[0] << "\"."; return true; } +bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* stb_op = it->get(); + + // 1. IDENTIFY OPERATORS + // *************************************************************************** + // SpaceToBatch Op. + if (stb_op->type != OperatorType::kSpaceToBatchND) { + return false; + } + if (stb_op->inputs.size() != 3) { + return false; + } + CHECK_EQ(stb_op->outputs.size(), 1); + // Extract the dilation factor from Input[1] of SpaceToBatch + // TODO(mjmatthews): Support 2D dilation factors. + const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); + if (!block_shape_array.buffer) { + return false; + } + CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); + int dilation_factor = + block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0]; + + // Expand Op + auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); + if (!post_stb_op) { + return false; + } + bool has_expand_op = false; + if (post_stb_op->type == OperatorType::kExpandDims) { + has_expand_op = true; + CHECK_EQ(post_stb_op->inputs.size(), 2); + CHECK_EQ(post_stb_op->outputs.size(), 1); + } + + // Conv Op + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + bool changed = false; + if (conv_base_op->type == OperatorType::kConv) { + changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op, + post_stb_op, has_expand_op, + dilation_factor); + if (changed) { + LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \"" + << conv_base_op->outputs[0] << "\"."; + } + } else if (identify_depthwise_conv_ && + conv_base_op->type == OperatorType::kDepthwiseConv) { + changed = ResolveDilatedConv<DepthwiseConvOperator>( + model, conv_base_op, stb_op, post_stb_op, has_expand_op, + dilation_factor); + if (changed) { + LOG(INFO) + << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \"" + << conv_base_op->outputs[0] << "\"."; + } + } + + return changed; +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 6c72e20121..f943da6d85 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, - op->stride_height, 1, 1, op->padding.type, + op->stride_height, op->dilation_width_factor, + op->dilation_height_factor, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 4c678e7e73..e02d000e7e 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -641,6 +641,23 @@ tensorflow::Status ConvertDepthwiseConvOperator( CHECK_EQ(strides.i(3), 1); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); + if (HasAttr(node, "dilations")) { + const auto& dilations = GetListAttr(node, "dilations"); + TF_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); + } + conv->dilation_height_factor = dilations.i(1); + conv->dilation_width_factor = dilations.i(2); + } else { + conv->dilation_height_factor = 1; + conv->dilation_width_factor = 1; + } const auto& padding = GetStringAttr(node, "padding"); if (padding == "SAME") { conv->padding.type = PaddingType::kSame; diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 28d31e3797..a08b02485f 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -101,7 +101,6 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowSwitch); transformations->Add(new ResolveTensorFlowConcat); transformations->Add(new ResolveMultiplyByZero); - transformations->Add(new IdentifyDilatedConv); transformations->Add(new IdentifyL2Normalization); transformations->Add(new IdentifyL2Pool); transformations->Add(new IdentifyRelu1); @@ -282,6 +281,14 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } } transformations.Add(new ResolveConstantConcatenation); + // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise + // conv, so we need to make sure we don't convert to dilated depthwise conv + // when outputing to TF GraphDef. + auto* identify_dilated_conv = new IdentifyDilatedConv; + if (output_format == TENSORFLOW_GRAPHDEF) { + identify_dilated_conv->set_identify_depthwise_conv(false); + } + transformations.Add(identify_dilated_conv); RunGraphTransformations(model, "general graph transformations", transformations); |