diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-07-18 03:10:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 03:13:13 -0700 |
commit | b74f7b71fad773dd90c8f48b66bc82fb07eb9bc0 (patch) | |
tree | 712a3021c27a7bd044b7e8237ec1b281f20680ff | |
parent | 3a576d3a2847cce68c4c4565f8a1124d7421ca3e (diff) |
Implement BitonicSort for GPU.
This is a first version, several things are still missing:
- Support for key/value sorting.
- Support for other types than F32, S32 and U32.
- Parallelization of the inner loop.
PiperOrigin-RevId: 205052657
4 files changed, 206 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index fe83d017f4..a08b72e3af 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction*) { - // TODO(b/26783907): Implement sort on GPU. - return Unimplemented("sort"); +Status IrEmitter::HandleSort(HloInstruction* sort) { + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + if (values != nullptr) { + // TODO(b/26783907): Also sort the values by their corresponding key. + return Unimplemented("Key/Value Sort is not implemented on GPU"); + } + int dimension_to_sort = sort->dimensions(0); + const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort); + const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort); + + const Shape& keys_shape = keys->shape(); + + // TODO(b/26783907): This case can probably be avoided with the Algebraic + // Simplifier. + if (ShapeUtil::IsScalar(keys_shape)) { + return Status::OK(); + } + + // Create loop nests which loop through the operand dimensions. The sort + // dimension is handled in three separate innermost loops which perform the + // sorting. + llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_); + llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest( + keys_array, dimension_to_sort, "keys", &loop_nest); + + // 'compare_keys_index' is the index of the element that 'keys_index' should + // be compared to. 
+ llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType()); + for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { + if (dimension != dimension_to_sort) { + compare_keys_index.push_back(keys_index[dimension]); + } else { + compare_keys_index.push_back(nullptr); + } + } + + // Create the sorting loops which do the sorting. + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/ + tensorflow::Log2Ceiling64(dimension_to_sort_bound), + /*suffix=*/"sort_stages"); + std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop( + /*suffix=*/"mask", + /*start_index=*/keys_index.GetConstantWithIndexType(0), + /*end_index=*/stages_loop->GetIndVarValue()); + std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/dimension_to_sort_bound, + /*suffix=*/"compare"); + + // Naive C++ code for the inner loops (without parallelization): + // + // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound); + // ++stage) { + // int64 first_xor_mask = (1LL << (stage + 1)) - 1; + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ first_xor_mask; + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // for (int64 mask = 0; mask < stage; ++mask) { + // int64 later_xor_mask = (1LL << (stage - (mask + 1))); + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ later_xor_mask; + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + 
SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_); + // The first xor mask of a stage is 2^(stage + 1) - 1. + auto first_xor_mask = ir_builder_.CreateSub( + ir_builder_.CreateShl( + keys_index.GetConstantWithIndexType(1), + ir_builder_.CreateAdd(stages_loop->GetIndVarValue(), + keys_index.GetConstantWithIndexType(1))), + keys_index.GetConstantWithIndexType(1)); + std::unique_ptr<llvm_ir::ForLoop> first_compare_loop = + llvm_ir::ForLoop::EmitForLoop( + /*prefix=*/"first_compare", + /*start_index=*/keys_index.GetConstantWithIndexType(0), + /*end_index=*/ + keys_index.GetConstantWithIndexType( + keys_shape.dimensions(dimension_to_sort)), + /*step=*/keys_index.GetConstantWithIndexType(1), + /*ir_builder=*/&ir_builder_); + + SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_); + // 'first_compare_loop' iterates through the 'dimension_to_sort'. + keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue(); + compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor( + first_compare_loop->GetIndVarValue(), first_xor_mask); + EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, + target_array); + + SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_); + // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)). + auto later_xor_mask = ir_builder_.CreateShl( + keys_index.GetConstantWithIndexType(1), + ir_builder_.CreateSub( + stages_loop->GetIndVarValue(), + ir_builder_.CreateAdd(mask_loop->GetIndVarValue(), + keys_index.GetConstantWithIndexType(1)))); + + SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_); + // 'compare_loop' iterates through the 'dimension_to_sort'. 
+ keys_index[dimension_to_sort] = compare_loop->GetIndVarValue(); + compare_keys_index[dimension_to_sort] = + ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask); + EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, + target_array); + + // Set the IR builder insert point to the exit basic block of the outer most + // loop. This ensures later instructions are inserted after this loop nest. + ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + + return Status::OK(); } Status IrEmitter::HandleSend(HloInstruction*) { @@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, return Status::OK(); } +void IrEmitter::EmitCompareLoop( + int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index, + const llvm_ir::IrArray::Index& compare_keys_index, + const llvm_ir::IrArray& keys_array) { + // TODO(b/26783907): parallelize this loop. + + // if (is_smaller_index && + // compare_keys[dimension_to_sort] < dimension_to_sort_bound) + llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT( + keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); + int64 dimension_to_sort_bound = + keys_array.GetShape().dimensions(dimension_to_sort); + auto if_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateAnd( + is_smaller_index, + ir_builder_.CreateICmpSLT( + compare_keys_index[dimension_to_sort], + keys_index.GetConstantWithIndexType(dimension_to_sort_bound))), + "smaller_comparison_index", &ir_builder_, /*emit_else=*/false); + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_); + auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_); + auto key_type = keys_array.GetShape().element_type(); + auto comparison = + primitive_util::IsFloatingPointType(key_type) + // TODO(b/26783907): Figure out how to handle NaNs. + ? 
ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2) + : ir_builder_.CreateICmp( + primitive_util::IsSignedIntegralType(key_type) + ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + key1, key2); + auto min_key = ir_builder_.CreateSelect(comparison, key1, key2); + auto max_key = ir_builder_.CreateSelect(comparison, key2, key1); + keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_); + keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_); +} + Status IrEmitter::EmitAtomicOperationForNestedComputation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index d2dd335f10..e9ad4a752b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* output_address, llvm::Value* source_address); + // A helper method for HandleSort(). It adds the inner comparison loop where + // we compare elements pointed to by 'keys_index' and 'compare_keys_index'. 
+ void EmitCompareLoop(int64 dimension_to_sort, + const llvm_ir::IrArray::Index& keys_index, + const llvm_ir::IrArray::Index& compare_keys_index, + const llvm_ir::IrArray& keys_array); + StatusOr<llvm::Value*> ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f2597da4b9..70a227ca4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + std::vector<std::unique_ptr<Thunk>> thunks; + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + if (values != nullptr) { + // TODO(b/26783907): Also sort the values by their corresponding key. + return Unimplemented("Key/Value Sort is not implemented on GPU"); + } + + // First copy the operand to the output, so that we can sort in-place. + // TODO(b/26783907): Share buffer of output and operand when it is possible. 
+ if (sort->operand(0)->IsConstant()) { + thunks.push_back(MakeUnique<HostToDeviceCopyThunk>( + /*source_address=*/sort->operand(0)->literal().untyped_data(), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } else { + thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + /*source_address=*/GetAllocationSlice(*sort->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } + + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + thunk_sequence_->emplace_back( + MakeUnique<SequentialThunk>(std::move(thunks), sort)); + return IrEmitter::HandleSort(sort); +} + Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { thunk_sequence_->push_back( BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 59547c16d7..616d8a2206 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* gen_token) override; |