diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-09-13 15:01:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 15:09:10 -0700 |
commit | fb50c8e9a3cb2ccfac9cf4a847d5841cba80b524 (patch) | |
tree | 969d44684c674fad7ef775323f6e150e616890d9 /tensorflow/contrib/lite/toco | |
parent | e8af4e1bb9496c111530e88263fb1b8dac8bdde9 (diff) |
Dilated Depthwise Conv reference implementations.
PiperOrigin-RevId: 212884951
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r-- | tensorflow/contrib/lite/toco/model.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 14 |
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 2e100e37f6..164b70f2df 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator { int stride_height = 0; int stride_width = 0; int depth_multiplier = 0; + // A dilation_rate of 0 is invalid and this field is an optional attribute. + // Thus initializing it to 1 to allow default conv behavior when the + // attribute is not present. + int dilation_width_factor = 1; + int dilation_height_factor = 1; }; // Depth-to-space transform operator. diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 5486012176..1061e7c7c4 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -107,7 +107,8 @@ class DepthwiseConvolution ActivationFunction::Serialize(op.fused_activation_function); return ::tflite::CreateDepthwiseConv2DOptions( *builder, padding, op.stride_width, op.stride_height, - op.depth_multiplier, activation_function); + op.depth_multiplier, activation_function, op.dilation_width_factor, + op.dilation_height_factor); } void ReadOptions(const TfLiteOptions& options, @@ -118,9 +119,18 @@ class DepthwiseConvolution op->depth_multiplier = options.depth_multiplier(); op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + op->dilation_width_factor = options.dilation_w_factor(); + op->dilation_height_factor = options.dilation_h_factor(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; + } }; class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, |