diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 61 |
1 files changed, 15 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f91cc00d71..b669881026 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config<CudnnConvBackendConfig>()); const auto& target = custom_call->custom_call_target(); - std::unique_ptr<ConvolutionThunk> thunk; + BufferAllocation::Slice input_slice, filter_slice, output_slice; + if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = rhs_slice; + output_slice = conv_result_slice; } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = conv_result_slice; + filter_slice = rhs_slice; + output_slice = lhs_slice; } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique<ConvolutionThunk>( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = conv_result_slice; + output_slice = rhs_slice; } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); } - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>( + Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice, + output_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } |