about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author    Tim Shen <timshen@google.com>  2018-09-10 16:59:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-10 17:14:29 -0700
commit    fea74706aaa314cc77ec66c2c986365590e8df27 (patch)
tree      cc2e225fc3e7dd94efc23ecd4472104ed24987b8
parent    c277998e9f82660b1573fd5587780a97db761a65 (diff)
Cleanup cudnn_convolution_runner's interface. Use a struct to pack most
of the parameters, so that it's easier to toss them around. PiperOrigin-RevId: 212361326
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc                    7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc   8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc            81
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h             44
4 files changed, 65 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 05448d863d..9b567cf4a8 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -72,9 +72,10 @@ Status ConvolutionThunk::ExecuteOnStream(
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(
- convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_,
- feature_group_count_, algorithm_config, stream));
+ {convolution_kind_, &input_shape_, &filter_shape_, &output_shape_,
+ input_data, filter_data, output_data, &window_, &dim_nums_,
+ feature_group_count_, algorithm_config},
+ scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5c2555148a..8fcff84173 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -295,10 +295,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< instr->ToString();
bool launch_ok =
- RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf,
- filter_buf, output_buf, &scratch_allocator, window, dnums,
- feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
+ RunCudnnConvolution({kind, &input_shape, &filter_shape, &output_shape,
+ input_buf, filter_buf, output_buf, &window, &dnums,
+ feature_group_count, AlgorithmConfig(alg)},
+ &scratch_allocator, &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9d1f..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
- output_buf, &scratch_allocator, window, dnums, feature_group_count,
- algorithm, stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf), scratch_allocator, window,
- dnums, feature_group_count, algorithm, stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index a1b4fc71d0..381aa37a1b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@ enum class CudnnConvKind {
kBackwardFilter, // input + output => filter
};
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
//
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment convolution with half data type is implemented with cudnn
+// PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
@@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
} // namespace gpu
} // namespace xla