diff options
author | Tim Shen <timshen@google.com> | 2018-09-19 12:40:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 12:44:07 -0700 |
commit | ff11877b101fe9c19021e8d7b43841031eb71cc3 (patch) | |
tree | 94ab2b04488b91b3b62dc3976a255345d0c9ae02 /tensorflow/compiler/xla/service/gpu | |
parent | 0b97c406413fbf71897b28461b35470f3f14fd7e (diff) |
Simplify ir_emitter_unnested so that it doesn't take a look at conv
custom call and try to understand what's inside. convolution_thunk does
it anyway.
PiperOrigin-RevId: 213676051
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
3 files changed, 49 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 3a23ac1d63..85f3682a5a 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,21 +29,51 @@ limitations under the License. namespace xla { namespace gpu { -using se::dnn::AlgorithmDesc; +ConvolutionThunk::ConvolutionThunk( + const HloCustomCallInstruction* cudnn_call, + std::vector<BufferAllocation::Slice> operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + operand_buffers_(std::move(operand_slices)), + result_buffer_(result_slice), + scratch_buffer_(scratch_slice), + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { CudnnConvParams params; + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params)); + + switch (params.kind) { + case CudnnConvKind::kForward: + params.input_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[0]); + params.filter_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[1]); + params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_); + break; + case CudnnConvKind::kBackwardInput: + params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_); + params.filter_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[1]); + params.output_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[0]); + break; + case CudnnConvKind::kBackwardFilter: + params.input_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[0]); + params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_); + params.output_buf = + buffer_allocations.GetDeviceAddress(operand_buffers_[1]); + 
break; + } - params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); - params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); - params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params)); - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d7d1f91fba..f53bc54198 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // Note that "output" here doesn't refer to the output from running this - // thunk, but rather to the "output" of a hypothetical forward convolution - // that corresponds to this input+filter+output triple. That is, the result - // generated by this thunk is "output" for forward convs, "input" for - // backward-input convs, and "filter" for backward-filter convs. + // operand_slices should be in the same order as cudnn_call->operands(). 
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, - BufferAllocation::Slice input_slice, - BufferAllocation::Slice filter_slice, - BufferAllocation::Slice output_slice, + std::vector<BufferAllocation::Slice> operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice tuple_result_slice) - : Thunk(Kind::kConvolution, cudnn_call), - cudnn_call_(cudnn_call), - input_buffer_(std::move(input_slice)), - filter_buffer_(std::move(filter_slice)), - output_buffer_(std::move(output_slice)), - scratch_buffer_(std::move(scratch_slice)), - tuple_result_buffer_(std::move(tuple_result_slice)) {} + BufferAllocation::Slice tuple_result_slice); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk { private: const HloCustomCallInstruction* cudnn_call_; - BufferAllocation::Slice input_buffer_; - BufferAllocation::Slice filter_buffer_; - BufferAllocation::Slice output_buffer_; + std::vector<BufferAllocation::Slice> operand_buffers_; + BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; BufferAllocation::Slice tuple_result_buffer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b669881026..c792dd2ddb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector<BufferAllocation::Slice> operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for 
(const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const auto& target = custom_call->custom_call_target(); - BufferAllocation::Slice input_slice, filter_slice, output_slice; - - if (target == kCudnnConvForwardCallTarget) { - input_slice = lhs_slice; - filter_slice = rhs_slice; - output_slice = conv_result_slice; - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_slice = conv_result_slice; - filter_slice = rhs_slice; - output_slice = lhs_slice; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_slice = lhs_slice; - filter_slice = conv_result_slice; - output_slice = rhs_slice; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>( - Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice, - output_slice, scratch_slice, tuple_result_slice)); + Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } |