aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-13 15:01:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 15:09:10 -0700
commitfb50c8e9a3cb2ccfac9cf4a847d5841cba80b524 (patch)
tree969d44684c674fad7ef775323f6e150e616890d9 /tensorflow/contrib/lite/toco
parente8af4e1bb9496c111530e88263fb1b8dac8bdde9 (diff)
Dilated Depthwise Conv reference implementations.
PiperOrigin-RevId: 212884951
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r--tensorflow/contrib/lite/toco/model.h5
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc14
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e37f6..164b70f2df 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator {
int stride_height = 0;
int stride_width = 0;
int depth_multiplier = 0;
+ // A dilation_rate of 0 is invalid and this field is an optional attribute.
+ // Thus initializing it to 1 to allow default conv behavior when the
+ // attribute is not present.
+ int dilation_width_factor = 1;
+ int dilation_height_factor = 1;
};
// Depth-to-space transform operator.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 5486012176..1061e7c7c4 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@ class DepthwiseConvolution
ActivationFunction::Serialize(op.fused_activation_function);
return ::tflite::CreateDepthwiseConv2DOptions(
*builder, padding, op.stride_width, op.stride_height,
- op.depth_multiplier, activation_function);
+ op.depth_multiplier, activation_function, op.dilation_width_factor,
+ op.dilation_height_factor);
}
void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@ class DepthwiseConvolution
op->depth_multiplier = options.depth_multiplier();
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ op->dilation_width_factor = options.dilation_w_factor();
+ op->dilation_height_factor = options.dilation_h_factor();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+ }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,