path: root/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
author     Adrian Kuegel <akuegel@google.com>  2018-09-03 04:52:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-03 04:57:16 -0700
commit     0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
tree       cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
parent     44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cudnn custom calls.

PiperOrigin-RevId: 211339551
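For context, the heart of the change is a single extra call on StreamExecutor's convolution descriptor, visible in the @@ -144,6 +145,7 @@ hunk below. The sketch that follows is illustrative only and not part of the patch: the include path, namespace alias, helper name MakeGroupedConvDescriptor, and the fixed 2-D dimension count are assumptions; only the set_group_count call mirrors the patched code.

// Illustrative sketch, not from the patch: expressing a grouped convolution
// through StreamExecutor's DNN API instead of rewriting the graph with
// ConvolutionFeatureGroupConverter. Header path, helper name, and the 2-D
// dimension count are assumptions.
#include "tensorflow/stream_executor/dnn.h"

namespace se = ::stream_executor;

se::dnn::ConvolutionDescriptor MakeGroupedConvDescriptor(
    int64_t feature_group_count) {
  // Two spatial dimensions (height and width), as in an ordinary 2-D conv.
  se::dnn::ConvolutionDescriptor convolution_descriptor(/*ndims=*/2);
  // Hand the group count straight to cudnn; it performs the grouped
  // convolution itself, so no pre-pass has to split the operands.
  convolution_descriptor.set_group_count(feature_group_count);
  return convolution_descriptor;
}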
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 54
1 file changed, 28 insertions(+), 26 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 07b96fbd3f..05125e9d1f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -77,8 +77,9 @@ Status RunCudnnConvolution(
const Shape& output_shape, DeviceMemory<T> input_buf,
DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
- Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ AlgorithmConfig algorithm, Stream* stream,
+ ProfileResult* profile_result /*= nullptr*/) {
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -144,6 +145,7 @@ Status RunCudnnConvolution(
}
ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+ convolution_descriptor.set_group_count(feature_group_count);
for (int dim = 0; dim < num_dimensions; ++dim) {
convolution_descriptor
.set_zero_padding(
@@ -222,14 +224,14 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
+ output_buf, &scratch_allocator, window, dnums, feature_group_count,
+ algorithm, stream, profile_result);
}
Status RunCudnnConvolution(
@@ -237,32 +239,32 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
case F32:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
+ feature_group_count, algorithm, stream, profile_result);
case F64:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
default:
LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}