author    Adrian Kuegel <akuegel@google.com>  2018-07-18 03:10:36 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-18 03:13:13 -0700
commit    b74f7b71fad773dd90c8f48b66bc82fb07eb9bc0 (patch)
tree      712a3021c27a7bd044b7e8237ec1b281f20680ff /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent    3a576d3a2847cce68c4c4565f8a1124d7421ca3e (diff)
Implement BitonicSort for GPU.
This is a first version; several things are still missing:
- Support for key/value sorting.
- Support for types other than F32, S32 and U32.
- Parallelization of the inner loop.

PiperOrigin-RevId: 205052657
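The kernel built by BuildKernelThunk in the diff below evaluates a bitonic compare-and-swap network over the buffer it has just copied into place. As a rough host-side illustration only, here is a minimal sketch of that network in plain C++; the function name, the std::vector container, and the power-of-two length assumption are mine and are not taken from the XLA sources.

// Illustrative host-side sketch of an in-place bitonic sort. Assumes
// data.size() is a power of two; it mirrors the compare-and-swap pattern
// the emitted GPU kernel evaluates, not the actual IR emission code.
#include <algorithm>
#include <cstddef>
#include <vector>

void BitonicSortSketch(std::vector<float>& data) {
  const std::size_t n = data.size();
  // Build bitonic runs of doubling length k.
  for (std::size_t k = 2; k <= n; k *= 2) {
    // Merge each run with halving comparison stride j.
    for (std::size_t j = k / 2; j > 0; j /= 2) {
      // All compare-and-swap pairs within one (k, j) step are independent;
      // this inner loop is the one the message above proposes to parallelize.
      for (std::size_t i = 0; i < n; ++i) {
        const std::size_t partner = i ^ j;
        if (partner > i) {
          // Comparison direction depends on which half of the length-k
          // block element i sits in.
          const bool ascending = ((i & k) == 0);
          if ((data[i] > data[partner]) == ascending) {
            std::swap(data[i], data[partner]);
          }
        }
      }
    }
  }
}

Because the pairs in each (k, j) step touch disjoint elements, mapping the innermost loop onto GPU threads is the natural next step; that is the "parallelization of the inner loop" listed as missing above.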
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 29
1 file changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f2597da4b9..70a227ca4a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   return IrEmitter::HandleSelect(select);
 }
 
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+
+  // First copy the operand to the output, so that we can sort in-place.
+  // TODO(b/26783907): Share buffer of output and operand when it is possible.
+  if (sort->operand(0)->IsConstant()) {
+    thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+        /*source_address=*/sort->operand(0)->literal().untyped_data(),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  } else {
+    thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  }
+
+  thunks.push_back(
+      BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+  thunk_sequence_->emplace_back(
+      MakeUnique<SequentialThunk>(std::move(thunks), sort));
+  return IrEmitter::HandleSort(sort);
+}
+
 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
   thunk_sequence_->push_back(
       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));