Diffstat (limited to 'tensorflow/contrib/lite/kernels/depthwise_conv.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv.cc  61  +++++++++++----
1 file changed, 45 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 347515f289..3e1ce60113 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
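
The hunk above folds dilation into both the output-size lambda and the
ComputePadding calls: a filter with k taps and dilation rate d covers an
effective window of (k - 1) * d + 1 input pixels, so that effective size
replaces the raw filter size in the VALID arithmetic, and the previously
hard-coded dilation rate of 1 gives way to the real factors. For reference
only, below is a minimal sketch of what a dilation-aware SAME padding helper
computes; ComputePaddingSketch is an illustrative stand-in, not the
padding.h implementation:

// Illustrative sketch: SAME padding is the per-side slack needed so that
// out_size strided windows of the effective (dilated) filter cover the input.
int ComputePaddingSketch(int stride, int dilation_rate, int in_size,
                         int filter_size, int out_size) {
  int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  int padding =
      ((out_size - 1) * stride + effective_filter_size - in_size) / 2;
  return padding > 0 ? padding : 0;
}

// Example: in_size = 32, filter_size = 3, stride = 1, dilation_rate = 2 gives
// an effective 5-tap window; SAME keeps out_size = 32, so the helper returns
// ((32 - 1) * 1 + 5 - 32) / 2 = 2 pixels of padding per side.
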
@@ -177,8 +182,19 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, float, float, float*,
+ const Dims<4>&);
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -188,7 +204,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(filter), GetTensorDims(filter),
GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_activation_min, output_activation_max,
GetTensorData<float>(output), GetTensorDims(output));
}
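
EvalFloat now threads the two dilation factors through to the selected
kernel, falling back to the reference implementation whenever either factor
is not 1. For orientation, this is roughly how a reference depthwise loop
consumes those factors; the function below is a single-channel 2D stand-in
with illustrative names, not the TFLite implementation:

// Single-channel 2D sketch: dilation stretches the filter taps apart on the
// input grid while the output grid is still walked by the strides.
void DepthwiseConv2DSketch(const float* input, int in_h, int in_w,
                           const float* filter, int f_h, int f_w,
                           int stride_h, int stride_w, int dilation_h,
                           int dilation_w, int pad_h, int pad_w, float* output,
                           int out_h, int out_w) {
  for (int oy = 0; oy < out_h; ++oy) {
    for (int ox = 0; ox < out_w; ++ox) {
      float acc = 0.f;
      for (int fy = 0; fy < f_h; ++fy) {
        for (int fx = 0; fx < f_w; ++fx) {
          // Tap (fy, fx) reads input pixels dilation_{h,w} apart; taps that
          // land in the zero-padding region are simply skipped.
          int iy = oy * stride_h + fy * dilation_h - pad_h;
          int ix = ox * stride_w + fx * dilation_w - pad_w;
          if (iy >= 0 && iy < in_h && ix >= 0 && ix < in_w) {
            acc += input[iy * in_w + ix] * filter[fy * f_w + fx];
          }
        }
      }
      output[oy * out_w + ox] = acc;
    }
  }
}

With dilation_h == dilation_w == 1 this degenerates to the ordinary strided
loop, which is why non-dilated call sites are unaffected by the change.
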
@@ -204,9 +221,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, int, int, int32, int32, int,
+ int32, int32, uint8*, const Dims<4>&);
+
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -216,7 +244,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_offset, data->output_multiplier,
data->output_shift, data->output_activation_min,
data->output_activation_max, GetTensorData<uint8_t>(output),
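
EvalQuantized gets the same treatment: the function-pointer signature grows
the two dilation ints plus the reference fallback, while the quantized
arithmetic itself is untouched, since dilation only changes which input
pixels each tap reads. A heavily simplified per-tap sketch of that
arithmetic follows; the plain shift stands in for TFLite's rounding
fixed-point rescale, so treat this as a sketch of the data flow, not the
exact rounding behavior:

#include <algorithm>
#include <cstdint>

// Per-tap quantized sketch: offsets recenter the uint8 codes (callers pass
// the negated zero points), accumulation happens in int32, and the result is
// rescaled, offset, and clamped to the activation range.
uint8_t QuantizedTapSketch(uint8_t input_val, int32_t input_offset,
                           uint8_t filter_val, int32_t filter_offset,
                           int32_t bias, int32_t output_offset,
                           int output_shift, int32_t activation_min,
                           int32_t activation_max) {
  int32_t acc = (static_cast<int32_t>(input_val) + input_offset) *
                (static_cast<int32_t>(filter_val) + filter_offset);
  acc += bias;
  acc = acc >> output_shift;  // stand-in for the real fixed-point multiply
  acc += output_offset;
  acc = std::max(acc, activation_min);
  acc = std::min(acc, activation_max);
  return static_cast<uint8_t>(acc);
}
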