author | Suharsh Sivakumar <suharshs@google.com> | 2018-08-14 21:22:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 21:26:54 -0700 |
commit | 7109b7267921f761636acdea7a03d5b212653f2b (patch) | |
tree | e0343c373caf3882377e44be94244dc09adff2b9 /tensorflow/contrib/lite/kernels/conv.cc | |
parent | d955bd55bb6138e908fe047152a1e1ac3f278aa9 (diff) | |
Quantized Dilated Convolution support.
Also add tests for float and quantized dilated convolution support, which seemed to be missing.
PiperOrigin-RevId: 208766032
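For context on the operation being enabled here: a dilated convolution spaces the filter taps `dilation` elements apart, so a filter of size k covers an effective extent of (k - 1) * dilation + 1 input elements. The sketch below is a minimal, hypothetical 1-D illustration of that indexing for quantized (uint8) data; it is not code from this commit, and the offset and requantization handling is simplified relative to the real TFLite kernels.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration (not from this commit): a naive 1-D dilated
// convolution over quantized uint8 data, accumulating in int32 as the
// TFLite kernels do. Requantization of the accumulator is omitted.
std::vector<int32_t> DilatedConv1D(const std::vector<uint8_t>& input,
                                   const std::vector<uint8_t>& filter,
                                   int32_t input_offset, int32_t filter_offset,
                                   int dilation) {
  // With dilation d, a filter of size k spans (k - 1) * d + 1 input elements.
  const int filter_size = static_cast<int>(filter.size());
  const int effective_filter_size = (filter_size - 1) * dilation + 1;
  const int output_size =
      static_cast<int>(input.size()) - effective_filter_size + 1;
  std::vector<int32_t> output;
  for (int out = 0; out < output_size; ++out) {
    int32_t acc = 0;
    for (int f = 0; f < filter_size; ++f) {
      // Dilation strides the filter taps across the input; dilation == 1
      // reduces to an ordinary convolution.
      acc += (static_cast<int32_t>(input[out + f * dilation]) + input_offset) *
             (static_cast<int32_t>(filter[f]) + filter_offset);
    }
    output.push_back(acc);
  }
  return output;
}
```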
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 48 |
1 file changed, 31 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 04c0263b78..50fe5c2e04 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -334,18 +334,31 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   auto filter_offset = -filter->params.zero_point;
   auto output_offset = output->params.zero_point;
 
-  switch (kernel_type) {
+  KernelType effective_kernel_type;
+  if ((kernel_type == kMultithreadOptimized ||
+       kernel_type == kCblasOptimized) &&
+      (params->dilation_width_factor != 1 ||
+       params->dilation_height_factor != 1)) {
+    // kMultithreadOptimized and kCblasOptimized do not support dilation.
+    // Therefore, fallback to optimized.
+    effective_kernel_type = kGenericOptimized;
+  } else {
+    effective_kernel_type = kernel_type;
+  }
+
+  switch (effective_kernel_type) {
     case kReference:
       reference_ops::Conv(
           GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
           GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
           GetTensorData<int32_t>(bias), GetTensorDims(bias),
-          params->stride_width, params->stride_height, data->padding.width,
-          data->padding.height, output_offset, data->output_multiplier,
-          data->output_shift, data->output_activation_min,
-          data->output_activation_max, GetTensorData<uint8_t>(output),
-          GetTensorDims(output), GetTensorData<uint8_t>(im2col),
-          GetTensorDims(im2col), gemm_context);
+          params->stride_width, params->stride_height,
+          params->dilation_width_factor, params->dilation_height_factor,
+          data->padding.width, data->padding.height, output_offset,
+          data->output_multiplier, data->output_shift,
+          data->output_activation_min, data->output_activation_max,
+          GetTensorData<uint8_t>(output), GetTensorDims(output),
+          GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
       break;
     case kGenericOptimized:
     case kMultithreadOptimized:
@@ -355,12 +368,13 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
           GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
           GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
           GetTensorData<int32_t>(bias), GetTensorDims(bias),
-          params->stride_width, params->stride_height, data->padding.width,
-          data->padding.height, output_offset, data->output_multiplier,
-          data->output_shift, data->output_activation_min,
-          data->output_activation_max, GetTensorData<uint8_t>(output),
-          GetTensorDims(output), GetTensorData<uint8_t>(im2col),
-          GetTensorDims(im2col), gemm_context);
+          params->stride_width, params->stride_height,
+          params->dilation_width_factor, params->dilation_height_factor,
+          data->padding.width, data->padding.height, output_offset,
+          data->output_multiplier, data->output_shift,
+          data->output_activation_min, data->output_activation_max,
+          GetTensorData<uint8_t>(output), GetTensorDims(output),
+          GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
       break;
   }
 }
@@ -374,10 +388,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
   KernelType effective_kernel_type;
-  if (((kernel_type == kMultithreadOptimized) ||
-       (kernel_type == kCblasOptimized)) &&
-      ((params->dilation_width_factor != 1) ||
-       (params->dilation_height_factor != 1))) {
+  if ((kernel_type == kMultithreadOptimized ||
+       kernel_type == kCblasOptimized) &&
+      (params->dilation_width_factor != 1 ||
+       params->dilation_height_factor != 1)) {
     // kMultithreadOptimized and kCblasOptimized do not support dilation.
     // Therefore, fallback to optimized.
     effective_kernel_type = kGenericOptimized;
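The behavioral core of the change is the kernel fallback: the multithreaded and CBLAS kernels cannot handle dilation, so dilated models are routed to the generic optimized kernel instead. Below is a standalone sketch of that selection logic, assuming the KernelType enum from conv.cc; SelectEffectiveKernelType is a hypothetical name, since the real code inlines this logic in EvalFloat and EvalQuantized.

```cpp
// Kernel variants as declared in conv.cc.
enum KernelType {
  kReference,
  kGenericOptimized,
  kMultithreadOptimized,
  kCblasOptimized,
};

// Hypothetical helper mirroring the fallback this commit applies in both
// the float and quantized evaluation paths.
KernelType SelectEffectiveKernelType(KernelType kernel_type,
                                     int dilation_width_factor,
                                     int dilation_height_factor) {
  const bool has_dilation =
      dilation_width_factor != 1 || dilation_height_factor != 1;
  const bool lacks_dilation_support =
      kernel_type == kMultithreadOptimized || kernel_type == kCblasOptimized;
  // kMultithreadOptimized and kCblasOptimized do not support dilation, so
  // fall back to the generic optimized kernel, which does.
  return (has_dilation && lacks_dilation_support) ? kGenericOptimized
                                                  : kernel_type;
}
```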