path: root/tensorflow/contrib/lite/kernels/conv.cc
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-13 08:12:42 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-13 08:16:46 -0700
commit 91c31997e6854a3d07acc76381cff7436df1c1dd (patch)
tree ee4f38d890771de2f2144c718a6421673eb81caf /tensorflow/contrib/lite/kernels/conv.cc
parent f9de043501e401af73aa02ab950864534f07c1df (diff)
Add support to TFLite for dilated convolution.
PiperOrigin-RevId: 192770919
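
The core of the change is the output-size rule: with dilation, a k-tap filter spans (k - 1) * d + 1 input elements. The standalone sketch below mirrors the computeOutSize lambda this patch adds to Prepare (the names EffectiveFilterSize, ComputeOutSize, and Padding here are illustrative, not TFLite's API):

// Standalone sketch of the output-size arithmetic introduced by this patch.
#include <cstdio>

enum Padding { kSame, kValid };

// With dilation, a k-tap filter spans (k - 1) * d + 1 input elements.
int EffectiveFilterSize(int filter_size, int dilation_rate) {
  return (filter_size - 1) * dilation_rate + 1;
}

int ComputeOutSize(Padding padding, int image_size, int filter_size,
                   int stride, int dilation_rate) {
  int effective_filter_size = EffectiveFilterSize(filter_size, dilation_rate);
  switch (padding) {
    case kSame:   // SAME: output size is independent of the filter.
      return (image_size + stride - 1) / stride;
    case kValid:  // VALID: only fully covered positions produce output.
      return (image_size - effective_filter_size + stride) / stride;
  }
  return 0;
}

int main() {
  // A 3-tap filter at dilation 2 spans 5 input elements, so a 10-wide
  // image yields (10 - 5 + 1) / 1 = 6 VALID output positions.
  std::printf("%d\n", ComputeOutSize(kValid, 10, 3, 1, 2));  // prints 6
}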
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r-- tensorflow/contrib/lite/kernels/conv.cc | 67
1 file changed, 42 insertions(+), 25 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 18ff33bf9f..3b467b3aa2 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -225,22 +225,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto computeOutSize = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto computeOutSize = [padding](int imageSize, int filterSize, int stride,
+ int dilationRate) -> int {
+ int effectiveFilterSize = (filterSize - 1) * dilationRate + 1;
return padding == kTfLitePaddingSame
? (imageSize + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (imageSize - effectiveFilterSize + stride) / stride
: 0;
};
- int outWidth = computeOutSize(width, filter_width, params->stride_width);
- int outHeight = computeOutSize(height, filter_height, params->stride_height);
+ int outWidth = computeOutSize(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
+ int outHeight = computeOutSize(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
data->padding.height =
- ComputePadding(params->stride_height, height, filter_height, outHeight);
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, outHeight);
data->padding.width =
- ComputePadding(params->stride_width, width, filter_width, outWidth);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, outWidth);
TF_LITE_ENSURE(context, hasBias);
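
Both call sites above now forward the dilation factor to ComputePadding, which is declared in kernels/padding.h and is not part of this diff. A definition consistent with these call sites would look like the sketch below (standard SAME-padding arithmetic; hedged, since the header is outside this patch):

// Sketch of a ComputePadding consistent with the call sites above; the
// actual definition lives in kernels/padding.h, outside this diff.
inline int ComputePadding(int stride, int dilation_rate, int in_size,
                          int filter_size, int out_size) {
  int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  // Total padding splits the slack evenly; clamp at zero when the
  // output already fits without padding.
  int padding = ((out_size - 1) * stride + effective_filter_size - in_size) / 2;
  return padding > 0 ? padding : 0;
}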
@@ -375,28 +380,40 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
-
- switch (kernel_type) {
+ KernelType effective_kernel_type;
+ if (((kernel_type == kMultithreadOptimized) ||
+ (kernel_type == kCblasOptimized)) &&
+ ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1))) {
+ // kMultithreadOptimized and kCblasOptimized do not support dilation.
+ // Therefore, fall back to the generic optimized kernel.
+ effective_kernel_type = kGenericOptimized;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+ switch (effective_kernel_type) {
case kReference: {
- reference_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, 1, 1,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ reference_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width,
+ data->padding.height, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
break;
}
case kGenericOptimized: {
- optimized_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, 1, 1,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ optimized_ops::Conv(
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width,
+ data->padding.height, output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
break;
}
case kMultithreadOptimized: {
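
The dispatch rewrite in EvalFloat centers on one decision: dilated convolutions cannot use the multithreaded or CBLAS kernels, so they fall back to kGenericOptimized. Restated as a standalone helper (EffectiveKernelType and HasDilation are illustrative names, not part of the patch):

// Standalone sketch of the kernel-selection fallback added in EvalFloat.
// The KernelType values mirror the diff; HasDilation is illustrative.
enum KernelType {
  kReference,
  kGenericOptimized,
  kMultithreadOptimized,
  kCblasOptimized,
};

bool HasDilation(int dilation_width_factor, int dilation_height_factor) {
  return dilation_width_factor != 1 || dilation_height_factor != 1;
}

KernelType EffectiveKernelType(KernelType requested, int dilation_w,
                               int dilation_h) {
  // Neither the multithreaded nor the CBLAS path supports dilation,
  // so dilated convolutions use the generic optimized kernel instead.
  if ((requested == kMultithreadOptimized || requested == kCblasOptimized) &&
      HasDilation(dilation_w, dilation_h)) {
    return kGenericOptimized;
  }
  return requested;
}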