aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-19 17:40:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 17:44:12 -0700
commit0ab89a599bdb9885532785a5e7b6bfe346e09ee3 (patch)
tree79bd3c43e9d9485f12054ef1a5c719a3b00a027e
parent5d2047029a77545c97c0fdf74d9c03c92d1dcb88 (diff)
TOCO transformations updated to support dilated depthwise convolution.
PiperOrigin-RevId: 213729750
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc117
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc3
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc17
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc9
6 files changed, 118 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 3a534300ae..3d1eb3978c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+  // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+  // the correct SpaceToBatchND and BatchToSpaceND operations before and after
+  // the conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632451..4d213b3f9c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@ DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool has_default_ranges_flag_ = false;
};
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "IdentifyDilatedConv"; }
+ bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+ void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+ bool identify_depthwise_conv_ = true;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857cfc2..aac77eb39e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@ namespace toco {
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+      LOG(INFO)
+          << "Replaced sub-network with Dilated DepthwiseConv2D op outputting \""
+          << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 6c72e20121..f943da6d85 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 4c678e7e73..e02d000e7e 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -641,6 +641,23 @@ tensorflow::Status ConvertDepthwiseConvOperator(
CHECK_EQ(strides.i(3), 1);
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
+ if (HasAttr(node, "dilations")) {
+ const auto& dilations = GetListAttr(node, "dilations");
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
+ conv->dilation_height_factor = dilations.i(1);
+ conv->dilation_width_factor = dilations.i(2);
+ } else {
+ conv->dilation_height_factor = 1;
+ conv->dilation_width_factor = 1;
+ }
const auto& padding = GetStringAttr(node, "padding");
if (padding == "SAME") {
conv->padding.type = PaddingType::kSame;
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 28d31e3797..a08b02485f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
- transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
@@ -282,6 +281,14 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
}
transformations.Add(new ResolveConstantConcatenation);
+  // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+  // conv, so we need to make sure we don't convert to dilated depthwise conv
+  // when outputting to TF GraphDef.
+ auto* identify_dilated_conv = new IdentifyDilatedConv;
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ identify_dilated_conv->set_identify_depthwise_conv(false);
+ }
+ transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);