diff options
author | Tim Shen <timshen@google.com> | 2018-09-10 16:59:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 17:14:29 -0700 |
commit | fea74706aaa314cc77ec66c2c986365590e8df27 (patch) | |
tree | cc2e225fc3e7dd94efc23ecd4472104ed24987b8 | |
parent | c277998e9f82660b1573fd5587780a97db761a65 (diff) |
Cleanup cudnn_convolution_runner's interface. Use a struct to pack most
of the parameters, so that it's easier to toss them around.
PiperOrigin-RevId: 212361326
4 files changed, 65 insertions, 75 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 05448d863d..9b567cf4a8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -72,9 +72,10 @@ Status ConvolutionThunk::ExecuteOnStream( auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + {convolution_kind_, &input_shape_, &filter_shape_, &output_shape_, + input_data, filter_data, output_data, &window_, &dim_nums_, + feature_group_count_, algorithm_config}, + scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5c2555148a..8fcff84173 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -295,10 +295,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << instr->ToString(); bool launch_ok = - RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, - filter_buf, output_buf, &scratch_allocator, window, dnums, - feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) + RunCudnnConvolution({kind, &input_shape, &filter_shape, &output_shape, + input_buf, filter_buf, output_buf, &window, &dnums, + feature_group_count, AlgorithmConfig(alg)}, + &scratch_allocator, &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 05125e9d1f..2a86ac265e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template <typename T> -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory<T> input_buf, - DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory<T> input_buf(params.input_buf); + DeviceMemory<T> filter_buf(params.filter_buf); + DeviceMemory<T> output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, - output_buf, &scratch_allocator, window, dnums, feature_group_count, - algorithm, stream, profile_result); + return RunCudnnConvolution(params, &scratch_allocator, stream, + profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = params.output_shape->element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator, + stream, profile_result); case F32: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), - se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, - feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<double>(input_buf), - se::DeviceMemory<double>(filter_buf), - se::DeviceMemory<double>(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream, + profile_result); default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index a1b4fc71d0..381aa37a1b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -47,6 +47,20 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter }; +struct CudnnConvParams { + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; +}; + // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); @@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. +// At the moment convolution with half data type is implemented with cudnn +// PSEUDO_HALF configuration, that is, the input values are half and the +// internal computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla |