aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
blob: 6f261c32f4181a6c4107f7fbcf782feb4347e587 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {
// Emits a single conditional compare-and-exchange of two elements of
// 'keys_array' (despite the historical name, no loop is emitted here).
//
// The element at 'keys_index' is paired with the element at
// 'compare_keys_index'. Only the 'dimension_to_sort' component of the two
// indices is compared (presumably the other components are identical; the
// caller in this file builds them that way). The exchange is guarded so that it
// only happens when:
//   * keys_index[dimension_to_sort] < compare_keys_index[dimension_to_sort]
//     (signed compare), which means a pair whose partner is derived
//     symmetrically (e.g. by xor) is exchanged only once, and
//   * compare_keys_index[dimension_to_sort] is smaller than the extent of
//     'dimension_to_sort', so out-of-range partners are skipped.
// Inside the guard, the smaller key is written to 'keys_index' and the larger
// one to 'compare_keys_index'.
//
// On return, the insert point of 'b' is inside the emitted "then" block.
void EmitCompareLoop(int64 dimension_to_sort,
                     const llvm_ir::IrArray::Index& keys_index,
                     const llvm_ir::IrArray::Index& compare_keys_index,
                     const llvm_ir::IrArray& keys_array, llvm::IRBuilder<>* b) {
  // TODO(b/26783907): parallelize this loop.
  llvm::Value* const lhs_position = keys_index[dimension_to_sort];
  llvm::Value* const rhs_position = compare_keys_index[dimension_to_sort];

  // Guard: the partner must come after us and lie inside the sort dimension.
  llvm::Value* const is_ordered_pair =
      b->CreateICmpSLT(lhs_position, rhs_position);
  const int64 sort_extent = keys_array.GetShape().dimensions(dimension_to_sort);
  llvm::Value* const partner_in_bounds = b->CreateICmpSLT(
      rhs_position, keys_index.GetConstantWithIndexType(sort_extent));
  llvm::Value* const should_exchange =
      b->CreateAnd(is_ordered_pair, partner_in_bounds);
  auto guard = llvm_ir::EmitIfThenElse(should_exchange,
                                       "smaller_comparison_index", b,
                                       /*emit_else=*/false);
  SetToFirstInsertPoint(guard.true_block, b);

  llvm::Value* const lhs_key = keys_array.EmitReadArrayElement(keys_index, b);
  llvm::Value* const rhs_key =
      keys_array.EmitReadArrayElement(compare_keys_index, b);

  // Pick the "less than" predicate that matches the element type.
  const auto element_type = keys_array.GetShape().element_type();
  llvm::Value* lhs_is_less = nullptr;
  if (primitive_util::IsFloatingPointType(element_type)) {
    // TODO(b/26783907): Figure out how to handle NaNs.
    // FCMP_ULT ("unordered or less than") is also true when either key is NaN,
    // so a pair involving a NaN keeps its current order.
    lhs_is_less = b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, lhs_key, rhs_key);
  } else {
    const auto predicate = primitive_util::IsSignedIntegralType(element_type)
                               ? llvm::ICmpInst::ICMP_SLT
                               : llvm::ICmpInst::ICMP_ULT;
    lhs_is_less = b->CreateICmp(predicate, lhs_key, rhs_key);
  }

  // Exchange: the smaller key goes to 'keys_index', the larger one to
  // 'compare_keys_index'.
  llvm::Value* const smaller = b->CreateSelect(lhs_is_less, lhs_key, rhs_key);
  llvm::Value* const larger = b->CreateSelect(lhs_is_less, rhs_key, lhs_key);
  keys_array.EmitWriteArrayElement(keys_index, smaller, b);
  keys_array.EmitWriteArrayElement(compare_keys_index, larger, b);
}
}  // namespace

// Emits one compare-exchange pass of a bitonic sort over 'keys_array' along
// 'dimension_to_sort', modifying the array in place.
//
// The emitted code works as follows:
//   * A sequential loop nest iterates over every dimension except
//     'dimension_to_sort'.
//   * Inside that nest, a loop visits each index i of the sort dimension. It is
//     emitted with gpu::ParallelLoopEmitter when 'launch_dimensions' is
//     non-null, and with a sequential LoopEmitter otherwise.
//   * Each i is paired with j = i ^ xor_mask, and the pair is compare-exchanged
//     by EmitCompareLoop.
// Presumably the caller emits one such pass per mask of the bitonic network.
//
// Parameters:
//   dimension_to_sort: the dimension of 'keys_array' to sort along.
//   keys_array: the array that is sorted in place.
//   name: base name for the emitted loops.
//   xor_mask: IR value XORed with the sort-dimension index to find the
//     partner element. It must have the loop's index type, because it is
//     combined with it via CreateXor.
//   b: IR builder. On success, its insert point is placed after all code
//     emitted here.
//   launch_dimensions: GPU launch configuration, or nullptr for a sequential
//     loop.
// Returns: OK, or the error reported by the loop emitter. A scalar
// 'keys_array' is already sorted, so nothing is emitted for it.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
                       tensorflow::StringPiece name, llvm::Value* xor_mask,
                       llvm::IRBuilder<>* b,
                       const gpu::LaunchDimensions* launch_dimensions) {
  const Shape& keys_shape = keys_array.GetShape();

  // TODO(b/26783907): This case can probably be avoided with the Algebraic
  // Simplifier.
  if (ShapeUtil::IsScalar(keys_shape)) {
    return Status::OK();
  }

  // Create loop nests which loop through the operand dimensions. The sort
  // dimension is handled in the innermost loop which performs the sorting.
  // For a rank-1 input the nest is empty (the only dimension is skipped).
  ForLoopNest loop_nest(name, b);
  IrArray::Index keys_index =
      loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys");
  if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) {
    SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b);
  }

  // 'compare_keys_index' is the index of the element that 'keys_index' should
  // be compared to. It shares every loop variable except the sort dimension,
  // which is filled in per element by the compare loop body below.
  IrArray::Index compare_keys_index(keys_index.GetType());
  for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
    if (static_cast<int64>(dimension) != dimension_to_sort) {
      compare_keys_index.push_back(keys_index[dimension]);
    } else {
      compare_keys_index.push_back(nullptr);
    }
  }

  // Naive C++ code for the inner compare loop:
  //
  // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
  //   int64 j = i ^ xor_mask;
  //   if (i < j && j < dimension_to_sort_bound) {
  //     int64 min_key = std::min(keys[i], keys[j]);
  //     keys[j] = std::max(keys[i], keys[j]);
  //     keys[i] = min_key;
  //   }
  // }
  //
  // This follows the algorithm described on Wikipedia:
  // https://en.wikipedia.org/wiki/Bitonic_sorter

  const int64 dimension_to_sort_bound =
      keys_shape.dimensions(dimension_to_sort);
  Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                             {dimension_to_sort_bound});
  auto compare_loop_body_emitter =
      [&](const IrArray::Index& compare_index) -> Status {
    keys_index[dimension_to_sort] = compare_index[0];
    compare_keys_index[dimension_to_sort] =
        b->CreateXor(compare_index[0], xor_mask);
    EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
                    keys_array, b);
    return Status::OK();
  };
  if (launch_dimensions != nullptr) {
    TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter,
                                                compare_shape,
                                                *launch_dimensions, b)
                           .EmitLoop(name));
  } else {
    TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b)
                           .EmitLoop(name));
  }

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop, so that later instructions are inserted after this loop nest. When
  // the nest is empty (rank-1 input) there is no such block. Calling
  // SetInsertPoint(nullptr) would dereference a null block, so the insert
  // point is kept where the compare-loop emitter left it instead (expected to
  // be its own exit block).
  llvm::BasicBlock* outer_loop_exit = loop_nest.GetOuterLoopExitBasicBlock();
  if (outer_loop_exit != nullptr) {
    b->SetInsertPoint(outer_loop_exit);
  }

  return Status::OK();
}

}  // namespace llvm_ir
}  // namespace xla