diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-07-18 03:10:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 03:13:13 -0700 |
commit | b74f7b71fad773dd90c8f48b66bc82fb07eb9bc0 (patch) | |
tree | 712a3021c27a7bd044b7e8237ec1b281f20680ff /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
parent | 3a576d3a2847cce68c4c4565f8a1124d7421ca3e (diff) |
Implement BitonicSort for GPU.
This is a first version; several things are still missing:
- Support for key/value sorting.
- Support for types other than F32, S32 and U32.
- Parallelization of the inner loop.
PiperOrigin-RevId: 205052657
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f2597da4b9..70a227ca4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + std::vector<std::unique_ptr<Thunk>> thunks; + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + if (values != nullptr) { + // TODO(b/26783907): Also sort the values by their corresponding key. + return Unimplemented("Key/Value Sort is not implemented on GPU"); + } + + // First copy the operand to the output, so that we can sort in-place. + // TODO(b/26783907): Share buffer of output and operand when it is possible. + if (sort->operand(0)->IsConstant()) { + thunks.push_back(MakeUnique<HostToDeviceCopyThunk>( + /*source_address=*/sort->operand(0)->literal().untyped_data(), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } else { + thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + /*source_address=*/GetAllocationSlice(*sort->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } + + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + thunk_sequence_->emplace_back( + MakeUnique<SequentialThunk>(std::move(thunks), sort)); + return IrEmitter::HandleSort(sort); +} + Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { thunk_sequence_->push_back( BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); |