aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/conv.cc
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-08-14 21:22:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 21:26:54 -0700
commit7109b7267921f761636acdea7a03d5b212653f2b (patch)
treee0343c373caf3882377e44be94244dc09adff2b9 /tensorflow/contrib/lite/kernels/conv.cc
parentd955bd55bb6138e908fe047152a1e1ac3f278aa9 (diff)
Quantized Dilated Convolution support.
Also add tests for float and quantized dilated support, which seemed to be missing. PiperOrigin-RevId: 208766032
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc48
1 files changed, 31 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 04c0263b78..50fe5c2e04 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -334,18 +334,31 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- switch (kernel_type) {
+ KernelType effective_kernel_type;
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
+ // kMultithreadOptimized and kCblasOptimized do not support dilation.
+ // Therefore, fallback to optimized.
+ effective_kernel_type = kGenericOptimized;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ switch (effective_kernel_type) {
case kReference:
reference_ops::Conv(
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ params->stride_width, params->stride_height,
+ params->dilation_width_factor, params->dilation_height_factor,
+ data->padding.width, data->padding.height, output_offset,
+ data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
break;
case kGenericOptimized:
case kMultithreadOptimized:
@@ -355,12 +368,13 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ params->stride_width, params->stride_height,
+ params->dilation_width_factor, params->dilation_height_factor,
+ data->padding.width, data->padding.height, output_offset,
+ data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
break;
}
}
@@ -374,10 +388,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
KernelType effective_kernel_type;
- if (((kernel_type == kMultithreadOptimized) ||
- (kernel_type == kCblasOptimized)) &&
- ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1))) {
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
// kMultithreadOptimized and kCblasOptimized do not support dilation.
// Therefore, fallback to optimized.
effective_kernel_type = kGenericOptimized;