diff options
author | Tim Shen <timshen@google.com> | 2018-09-19 12:40:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 12:44:07 -0700 |
commit | ff11877b101fe9c19021e8d7b43841031eb71cc3 (patch) | |
tree | 94ab2b04488b91b3b62dc3976a255345d0c9ae02 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
parent | 0b97c406413fbf71897b28461b35470f3f14fd7e (diff) |
Simplify ir_emitter_unnested so that it doesn't take a look at conv
custom call and try to understand what's inside. convolution_thunk does
it anyway.
PiperOrigin-RevId: 213676051
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 31 |
1 files changed, 7 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b669881026..c792dd2ddb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector<BufferAllocation::Slice> operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const auto& target = custom_call->custom_call_target(); - BufferAllocation::Slice input_slice, filter_slice, output_slice; - - if (target == kCudnnConvForwardCallTarget) { - input_slice = lhs_slice; - filter_slice = rhs_slice; - output_slice = conv_result_slice; - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_slice = conv_result_slice; - filter_slice = rhs_slice; - output_slice = lhs_slice; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_slice = lhs_slice; - filter_slice = conv_result_slice; - output_slice = rhs_slice; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>( - Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice, - output_slice, scratch_slice, tuple_result_slice)); + Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } |