diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/convolution_thunk.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 53 |
1 files changed, 12 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 05448d863d..3a23ac1d63 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,62 +31,32 @@ namespace gpu { using se::dnn::AlgorithmDesc; -ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, - int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - feature_group_count_(feature_group_count), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} - Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + CudnnConvParams params; + + params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); + params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); + params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. void* result_ptr = [&] { - switch (convolution_kind_) { + switch (params.kind) { case CudnnConvKind::kForward: - return output_data.opaque(); + return params.output_buf.opaque(); case CudnnConvKind::kBackwardInput: - return input_data.opaque(); + return params.input_buf.opaque(); case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); + return params.filter_buf.opaque(); } }(); void* ptrs[] = {result_ptr, scratch.opaque()}; |