aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-07-18 03:10:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 03:13:13 -0700
commitb74f7b71fad773dd90c8f48b66bc82fb07eb9bc0 (patch)
tree712a3021c27a7bd044b7e8237ec1b281f20680ff
parent3a576d3a2847cce68c4c4565f8a1124d7421ca3e (diff)
Implement BitonicSort for GPU.
This is a first version, several things are still missing: - Support for key/value sorting. - Support for other types than F32, S32 and U32. - Parallelization of the inner loop. PiperOrigin-RevId: 205052657
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc172
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc29
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h1
4 files changed, 206 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index fe83d017f4..a08b72e3af 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
return Status::OK();
}
-Status IrEmitter::HandleSort(HloInstruction*) {
- // TODO(b/26783907): Implement sort on GPU.
- return Unimplemented("sort");
+Status IrEmitter::HandleSort(HloInstruction* sort) {
+ auto keys = sort->operand(0);
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ if (values != nullptr) {
+ // TODO(b/26783907): Also sort the values by their corresponding key.
+ return Unimplemented("Key/Value Sort is not implemented on GPU");
+ }
+ int dimension_to_sort = sort->dimensions(0);
+ const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
+ const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
+
+ const Shape& keys_shape = keys->shape();
+
+ // TODO(b/26783907): This case can probably be avoided with the Algebraic
+ // Simplifier.
+ if (ShapeUtil::IsScalar(keys_shape)) {
+ return Status::OK();
+ }
+
+ // Create loop nests which loop through the operand dimensions. The sort
+ // dimension is handled in three separate innermost loops which perform the
+ // sorting.
+ llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
+ llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
+ keys_array, dimension_to_sort, "keys", &loop_nest);
+
+ // 'compare_keys_index' is the index of the element that 'keys_index' should
+ // be compared to.
+ llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
+ for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
+ if (dimension != dimension_to_sort) {
+ compare_keys_index.push_back(keys_index[dimension]);
+ } else {
+ compare_keys_index.push_back(nullptr);
+ }
+ }
+
+ // Create the sorting loops which do the sorting.
+ int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
+ std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/
+ tensorflow::Log2Ceiling64(dimension_to_sort_bound),
+ /*suffix=*/"sort_stages");
+ std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
+ /*suffix=*/"mask",
+ /*start_index=*/keys_index.GetConstantWithIndexType(0),
+ /*end_index=*/stages_loop->GetIndVarValue());
+ std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/dimension_to_sort_bound,
+ /*suffix=*/"compare");
+
+ // Naive C++ code for the inner loops (without parallelization):
+ //
+ // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+ // ++stage) {
+ // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+ // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+ // int64 j = i ^ first_xor_mask;
+ // if (i < j && j < dimension_to_sort_bound) {
+ // int64 min_key = std::min(keys[i], keys[j]);
+ // keys[j] = std::max(keys[i], keys[j]);
+ // keys[i] = min_key;
+ // }
+ // }
+ // for (int64 mask = 0; mask < stage; ++mask) {
+ // int64 later_xor_mask = (1LL << (stage - (mask + 1));
+ // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+ // int64 j = i ^ later_xor_mask;
+ // if (i < j && j < dimension_to_sort_bound) {
+ // int64 min_key = std::min(keys[i], keys[j]);
+ // keys[j] = std::max(keys[i], keys[j]);
+ // keys[i] = min_key;
+ // }
+ // }
+ // }
+ // }
+ //
+ // This follows the algorithm described on Wikipedia:
+ // https://en.wikipedia.org/wiki/Bitonic_sorter
+
+ SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
+ // The first xor mask of a stage is 2^(stage + 1) - 1.
+ auto first_xor_mask = ir_builder_.CreateSub(
+ ir_builder_.CreateShl(
+ keys_index.GetConstantWithIndexType(1),
+ ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
+ keys_index.GetConstantWithIndexType(1))),
+ keys_index.GetConstantWithIndexType(1));
+ std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
+ llvm_ir::ForLoop::EmitForLoop(
+ /*prefix=*/"first_compare",
+ /*start_index=*/keys_index.GetConstantWithIndexType(0),
+ /*end_index=*/
+ keys_index.GetConstantWithIndexType(
+ keys_shape.dimensions(dimension_to_sort)),
+ /*step=*/keys_index.GetConstantWithIndexType(1),
+ /*ir_builder=*/&ir_builder_);
+
+ SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
+ // 'first_compare_loop' iterates through the 'dimension_to_sort'.
+ keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
+ compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
+ first_compare_loop->GetIndVarValue(), first_xor_mask);
+ EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+ target_array);
+
+ SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
+ // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
+ auto later_xor_mask = ir_builder_.CreateShl(
+ keys_index.GetConstantWithIndexType(1),
+ ir_builder_.CreateSub(
+ stages_loop->GetIndVarValue(),
+ ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
+ keys_index.GetConstantWithIndexType(1))));
+
+ SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
+ // 'compare_loop' iterates through the 'dimension_to_sort'.
+ keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
+ compare_keys_index[dimension_to_sort] =
+ ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
+ EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+ target_array);
+
+ // Set the IR builder insert point to the exit basic block of the outer most
+ // loop. This ensures later instructions are inserted after this loop nest.
+ ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+ return Status::OK();
}
Status IrEmitter::HandleSend(HloInstruction*) {
@@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
return Status::OK();
}
+void IrEmitter::EmitCompareLoop(
+ int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
+ const llvm_ir::IrArray::Index& compare_keys_index,
+ const llvm_ir::IrArray& keys_array) {
+ // TODO(b/26783907): parallelize this loop.
+
+ // if (is_smaller_index &&
+ // compare_keys[dimension_to_sort] < dimension_to_sort_bound)
+ llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
+ keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
+ int64 dimension_to_sort_bound =
+ keys_array.GetShape().dimensions(dimension_to_sort);
+ auto if_data = llvm_ir::EmitIfThenElse(
+ ir_builder_.CreateAnd(
+ is_smaller_index,
+ ir_builder_.CreateICmpSLT(
+ compare_keys_index[dimension_to_sort],
+ keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
+ "smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
+ auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
+ auto key_type = keys_array.GetShape().element_type();
+ auto comparison =
+ primitive_util::IsFloatingPointType(key_type)
+ // TODO(b/26783907): Figure out how to handle NaNs.
+ ? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+ : ir_builder_.CreateICmp(
+ primitive_util::IsSignedIntegralType(key_type)
+ ? llvm::ICmpInst::ICMP_SLT
+ : llvm::ICmpInst::ICMP_ULT,
+ key1, key2);
+ auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
+ auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
+ keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
+ keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
+}
+
Status IrEmitter::EmitAtomicOperationForNestedComputation(
const HloComputation& computation, llvm::Value* output_address,
llvm::Value* source_address) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index d2dd335f10..e9ad4a752b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* output_address,
llvm::Value* source_address);
+ // A helper method for HandleSort(). It adds the inner comparison loop where
+ // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
+ void EmitCompareLoop(int64 dimension_to_sort,
+ const llvm_ir::IrArray::Index& keys_index,
+ const llvm_ir::IrArray::Index& compare_keys_index,
+ const llvm_ir::IrArray& keys_array);
+
StatusOr<llvm::Value*> ComputeNestedElement(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f2597da4b9..70a227ca4a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
return IrEmitter::HandleSelect(select);
}
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ if (values != nullptr) {
+ // TODO(b/26783907): Also sort the values by their corresponding key.
+ return Unimplemented("Key/Value Sort is not implemented on GPU");
+ }
+
+ // First copy the operand to the output, so that we can sort in-place.
+ // TODO(b/26783907): Share buffer of output and operand when it is possible.
+ if (sort->operand(0)->IsConstant()) {
+ thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+ /*source_address=*/sort->operand(0)->literal().untyped_data(),
+ /*destination_buffer=*/GetAllocationSlice(*sort),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+ } else {
+ thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+ /*destination_buffer=*/GetAllocationSlice(*sort),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+ }
+
+ thunks.push_back(
+ BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+ thunk_sequence_->emplace_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), sort));
+ return IrEmitter::HandleSort(sort);
+}
+
Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
thunk_sequence_->push_back(
BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 59547c16d7..616d8a2206 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleSort(HloInstruction* sort) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleAfterAll(HloInstruction* gen_token) override;