diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-09-03 04:52:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-03 04:57:16 -0700 |
commit | 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch) | |
tree | cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | |
parent | 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff) |
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the
ConvolutionFeatureGroupConverter pass and can instead set the group_count
parameter on the cudnn custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 54 |
1 file changed, 28 insertions(+), 26 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 07b96fbd3f..05125e9d1f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -77,8 +77,9 @@ Status RunCudnnConvolution( const Shape& output_shape, DeviceMemory<T> input_buf, DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, - Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -144,6 +145,7 @@ Status RunCudnnConvolution( } ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -222,14 +224,14 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, 
input_shape, filter_shape, output_shape, input_buf, filter_buf, + output_buf, &scratch_allocator, window, dnums, feature_group_count, + algorithm, stream, profile_result); } Status RunCudnnConvolution( @@ -237,32 +239,32 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<Eigen::half>(input_buf), + se::DeviceMemory<Eigen::half>(filter_buf), + se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window, + dnums, feature_group_count, algorithm, stream, profile_result); case F32: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), - se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<float>(input_buf), + se::DeviceMemory<float>(filter_buf), + se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, + feature_group_count, algorithm, stream, profile_result); case F64: - return RunCudnnConvolution(kind, input_shape, filter_shape, 
output_shape, - se::DeviceMemory<double>(input_buf), - se::DeviceMemory<double>(filter_buf), - se::DeviceMemory<double>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<double>(input_buf), + se::DeviceMemory<double>(filter_buf), + se::DeviceMemory<double>(output_buf), scratch_allocator, window, + dnums, feature_group_count, algorithm, stream, profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(output_shape); } |