author	Tim Shen <timshen@google.com>	2018-09-19 12:40:56 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-09-19 12:44:07 -0700
commit	ff11877b101fe9c19021e8d7b43841031eb71cc3 (patch)
tree	94ab2b04488b91b3b62dc3976a255345d0c9ae02
parent	0b97c406413fbf71897b28461b35470f3f14fd7e (diff)
Simplify ir_emitter_unnested so that it doesn't inspect the conv custom call and try to understand what's inside; convolution_thunk does that anyway.

PiperOrigin-RevId: 213676051
-rw-r--r--	tensorflow/compiler/xla/service/gpu/convolution_thunk.cc	| 42
-rw-r--r--	tensorflow/compiler/xla/service/gpu/convolution_thunk.h	| 25
-rw-r--r--	tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc	| 31
3 files changed, 49 insertions(+), 49 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1d63..85f3682a5a 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,21 +29,51 @@ limitations under the License.
namespace xla {
namespace gpu {
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+ const HloCustomCallInstruction* cudnn_call,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ operand_buffers_(std::move(operand_slices)),
+ result_buffer_(result_slice),
+ scratch_buffer_(scratch_slice),
+ tuple_result_buffer_(tuple_result_slice) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
+
+ switch (params.kind) {
+ case CudnnConvKind::kForward:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ break;
+ case CudnnConvKind::kBackwardInput:
+ params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.filter_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ params.input_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+ params.output_buf =
+ buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+ break;
+ }
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
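Two details are worth calling out in the rewritten ExecuteOnStream. First, the PopulateCudnnConvParams call is hoisted above the buffer assignment, because the new switch dispatches on params.kind, which that call fills in. Second, "input", "filter", and "output" name the tensors of the hypothetical *forward* convolution that corresponds to this op, so the thunk's result buffer plays a different role for each kind. A minimal self-contained sketch of that mapping, using stand-in types rather than the real BufferAllocation::Slice / CudnnConvParams API:

```cpp
#include <cstdio>

// Stand-ins for the real XLA types; illustrative only.
enum class CudnnConvKind { kForward, kBackwardInput, kBackwardFilter };
struct Buf { const char* label; };
struct Roles { Buf input, filter, output; };

// Mirrors the switch added to ConvolutionThunk::ExecuteOnStream above:
// the thunk's result buffer is the forward "output" only for kForward.
Roles AssignRoles(CudnnConvKind kind, Buf operand0, Buf operand1, Buf result) {
  switch (kind) {
    case CudnnConvKind::kForward:         // result <- forward output
      return {operand0, operand1, result};
    case CudnnConvKind::kBackwardInput:   // result <- forward input
      return {result, operand1, operand0};
    case CudnnConvKind::kBackwardFilter:  // result <- forward filter
      return {operand0, result, operand1};
  }
  return {};  // unreachable for valid enum values
}

int main() {
  Roles r = AssignRoles(CudnnConvKind::kBackwardInput,
                        {"operand0"}, {"operand1"}, {"result"});
  std::printf("input=%s filter=%s output=%s\n",
              r.input.label, r.filter.label, r.output.label);
}
```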
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91fba..f53bc54198 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk {
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
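With the new declaration, the caller passes the raw operands in the same order as cudnn_call->operands() and leaves the role assignment to the thunk. A sketch of the call-site shape, compilable on its own with stand-in types (the slice values below are hypothetical; in the real emitter they come from buffer assignment):

```cpp
#include <memory>
#include <utility>
#include <vector>

// Stand-ins for BufferAllocation::Slice and HloCustomCallInstruction.
struct Slice { int id = 0; };
struct HloCustomCall {};

class ConvolutionThunk {
 public:
  ConvolutionThunk(const HloCustomCall* /*call*/,
                   std::vector<Slice> operand_slices, Slice result,
                   Slice scratch, Slice tuple_result)
      : operand_buffers_(std::move(operand_slices)),
        result_buffer_(result),
        scratch_buffer_(scratch),
        tuple_result_buffer_(tuple_result) {}

 private:
  std::vector<Slice> operand_buffers_;
  Slice result_buffer_, scratch_buffer_, tuple_result_buffer_;
};

int main() {
  HloCustomCall call;
  std::vector<Slice> operand_slices = {{0}, {1}};  // operand order preserved
  auto thunk = std::make_unique<ConvolutionThunk>(
      &call, std::move(operand_slices), Slice{2}, Slice{3}, Slice{4});
  (void)thunk;
}
```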
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881026..c792dd2ddb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
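A side effect of this restructuring: the emitter's LOG(FATAL) on an unrecognized custom-call target disappears, and a malformed target now surfaces as the Status returned by PopulateCudnnConvParams when the thunk executes, rather than crashing during IR emission. A self-contained sketch of that Status-style dispatch, with stand-in types (the target strings match the kCudnnConv*CallTarget constants from this era of the codebase, but treat the names and error text here as illustrative, not the real PopulateCudnnConvParams implementation):

```cpp
#include <iostream>
#include <string>

enum class CudnnConvKind { kForward, kBackwardInput, kBackwardFilter };

// Stand-in for xla::Status in this sketch.
struct Status {
  bool ok;
  std::string message;
  static Status OK() { return {true, ""}; }
};

// Hypothetical analogue of the target-to-kind dispatch that used to sit
// in IrEmitterUnnested::HandleCustomCall behind a LOG(FATAL).
Status KindForTarget(const std::string& target, CudnnConvKind* kind) {
  if (target == "__cudnn$convForward") {
    *kind = CudnnConvKind::kForward;
  } else if (target == "__cudnn$convBackwardInput") {
    *kind = CudnnConvKind::kBackwardInput;
  } else if (target == "__cudnn$convBackwardFilter") {
    *kind = CudnnConvKind::kBackwardFilter;
  } else {
    return {false, "Unexpected custom call target: " + target};
  }
  return Status::OK();
}

int main() {
  CudnnConvKind kind;
  Status s = KindForTarget("__cudnn$convForward", &kind);
  std::cout << (s.ok ? "ok" : s.message) << "\n";
}
```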