aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-10 16:59:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 17:14:29 -0700
commitfea74706aaa314cc77ec66c2c986365590e8df27 (patch)
treecc2e225fc3e7dd94efc23ecd4472104ed24987b8 /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
parentc277998e9f82660b1573fd5587780a97db761a65 (diff)
Cleanup cudnn_convolution_runner's interface. Use a struct to pack most
of the parameters, so that it's easier to toss them around. PiperOrigin-RevId: 212361326
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc81
1 files changed, 33 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9d1f..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
- output_buf, &scratch_allocator, window, dnums, feature_group_count,
- algorithm, stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}