aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
diff options
context:
space:
mode:
authorGravatar Bixia Zheng <bixia@google.com>2018-06-19 23:35:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 23:38:38 -0700
commit4283949adca17d2fcbf49cf510fff961a572dbaf (patch)
tree41c300eb4a985be49991363fc7b1ce2bd3415bc4 /tensorflow/compiler/xla/service/elemental_ir_emitter.cc
parent7c754a6db364443c1103bd362e826fafab8f2718 (diff)
Allow the use of 32 bit integer type for loop index and tensor element index.
The GPU LLVM IR generator currently uses 64 bit integer type for arithmetic operations related to loop index and tensor element index and relies on LLVM optimization to narrow the operations to 32 bit integer type. There are situations where LLVM optimizations fail to perform such a narrowing, see LLVM D46760 for more detail. This change modifies the XLA LLVM IR code generation infrastructure to support the use of 32 bit integer type for loop index and tensor element index as follows: .Extends the loop emitter interface in ParallelLoopEmitter and ForLoopNest to allow users to specify the loop index type. .Modifies the tensor access interface in IrArray::Index interface to record the llvm type for the index when an object is constructed. This index type is usually propagated from a loop index type. .Modifies kernel_support_library to retrieve the loop index type from the input llvm::Value. .Modifies elemental_ir_emitter to retrieve the data type from the input IrArray::Index and use it in the tensor offset expression. This change also modifies the emission of the fusion kernel, the row and scalar reduction kernel and SelectAndScatter kernel to use 32 bit integer type for index calculation when the size of the launch dimension and the size of tensors used in the kernel are within the range of 32 bit integer representation. PiperOrigin-RevId: 201303468
Diffstat (limited to 'tensorflow/compiler/xla/service/elemental_ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc104
1 file changed, 60 insertions, 44 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 93fea7ead7..4ccd85307d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1220,7 +1220,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const Shape& operand_shape = hlo.operand(operand_no)->shape();
// If the operand is scalar, the source index is always {}.
if (ShapeUtil::IsScalar(operand_shape)) {
- return llvm_ir::IrArray::Index();
+ return llvm_ir::IrArray::Index(target_index.GetType());
}
// If no implicit broadcast is needed for this operand, returns the target
@@ -1232,13 +1232,13 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
// If implicit broadcast is needed, the source dimensions that are broadcast
// have index 0.
CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
- llvm_ir::IrArray::Index source_index;
+ llvm_ir::IrArray::Index source_index(target_index.GetType());
for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
source_index.push_back(target_index[i]);
} else {
CHECK_EQ(1, operand_shape.dimensions(i));
- source_index.push_back(ir_builder_->getInt64(0));
+ source_index.push_back(target_index.GetConstantWithIndexType(0));
}
}
return source_index;
@@ -1540,9 +1540,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
+ // Use the same index type for all tensor accesses in the same kernel.
+ llvm::Type* index_type = index.GetType();
+ llvm_ir::IrArray::Index slice_start_index(index_type, rank);
for (int64 i = 0; i < rank; ++i) {
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+ llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(hlo->operand(1))(dim_index));
@@ -1552,17 +1557,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to oficially document different behavior.
- start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
- index[i]->getType());
- llvm::Value* operand_dim_size = llvm::ConstantInt::get(
- start_index_value->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* output_dim_size = llvm::ConstantInt::get(
- start_index_value->getType(), hlo->shape().dimensions(i));
+ start_index_value =
+ ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
+ llvm::Value* operand_dim_size =
+ index_typed_const(input_hlo->shape().dimensions(i));
+ llvm::Value* output_dim_size =
+ index_typed_const(hlo->shape().dimensions(i));
start_index_value = EmitIntegralMin(
ir_builder_->CreateSub(operand_dim_size, output_dim_size),
- EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
- start_index_value, /*is_signed=*/true),
+ EmitIntegralMax(index_typed_const(0), start_index_value,
+ /*is_signed=*/true),
/*is_signed=*/true);
start_index_value->setName(
@@ -1570,7 +1575,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
slice_start_index[i] = start_index_value;
}
- llvm_ir::IrArray::Index input_index(rank);
+ llvm_ir::IrArray::Index input_index(index_type, rank);
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
@@ -1594,17 +1599,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
const llvm_ir::ElementGenerator& indices_generator =
operand_to_generator.at(hlo->operand(1));
+ llvm::Type* index_type = index.GetType();
// This is the index into `operand` that holds the element we want to
// generate. This index "unsafe" as in the components in here may be
// out of bounds.
- IrArray::Index unsafe_operand_index;
+ IrArray::Index unsafe_operand_index(index_type);
// First copy in the window indices to unsafe_operand_index.
for (int64 i = 0, e = operand_shape.dimensions_size(),
unsafe_operand_index_dim = 0;
i < e; i++) {
if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
- unsafe_operand_index.push_back(ir_builder_->getInt64(0));
+ unsafe_operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
unsafe_operand_index.push_back(
index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
@@ -1612,7 +1618,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
}
// This is the index of the index vector in the gather_indices tensor.
- IrArray::Index gather_index_index;
+ IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
@@ -1628,8 +1634,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
int64 dim) {
- llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc(
- index_component, ir_builder_->getInt64Ty());
+ llvm::Value* gather_dim_component_extended =
+ ir_builder_->CreateSExtOrTrunc(index_component, index_type);
unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
ir_builder_->CreateAdd(
unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
@@ -1645,18 +1651,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
indices_shape.dimensions(dim_numbers.index_vector_dim());
for (int64 i = 0; i < index_vector_size; i++) {
gather_index_index[dim_numbers.index_vector_dim()] =
- ir_builder_->getInt64(i);
+ index.GetConstantWithIndexType(i);
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
add_to_unsafe_operand_index(gather_dim_component, i);
}
}
- IrArray::Index safe_operand_index;
+ IrArray::Index safe_operand_index(index_type);
for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
safe_operand_index.push_back(ir_builder_->CreateURem(
unsafe_operand_index[i],
- ir_builder_->getInt64(operand_shape.dimensions(i))));
+ index.GetConstantWithIndexType(operand_shape.dimensions(i))));
}
return operand_generator(safe_operand_index);
@@ -1671,14 +1677,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* start_hlo = hlo->operand(2);
// Calculate slice start/end indices.
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
- llvm_ir::IrArray::Index slice_limit_index(rank);
+ llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank);
+ llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank);
// Slice intersection gathers (ANDs) conditions on all ranks for which
// 'input' is set to 'update'
llvm::Value* slice_intersection = ir_builder_->getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ llvm::Type* index_type = index[0]->getType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+ llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(start_hlo)(dim_index));
@@ -1688,18 +1698,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
// to oficially document different behavior.
- start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
- index[i]->getType());
- llvm::Value* input_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* update_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), update_hlo->shape().dimensions(i));
-
- start_index_value = EmitIntegralMin(
- ir_builder_->CreateSub(input_dim_size, update_dim_size),
- EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
- start_index_value, /*is_signed=*/true),
- /*is_signed=*/true);
+ start_index_value =
+ ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
+ llvm::Value* input_dim_size =
+ index_typed_const(input_hlo->shape().dimensions(i));
+ llvm::Value* update_dim_size =
+ index_typed_const(update_hlo->shape().dimensions(i));
+
+ start_index_value =
+ EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size),
+ EmitIntegralMax(index_typed_const(0), start_index_value,
+ /*is_signed=*/true),
+ /*is_signed=*/true);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
@@ -1729,7 +1739,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Handle true BB (return data from 'update')
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
// Compute update index for intersection case.
- llvm_ir::IrArray::Index update_index(rank);
+ llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
}
@@ -1797,7 +1807,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
- operand_to_generator.at(hlo->operand(1))({}));
+ operand_to_generator.at(hlo->operand(1))(
+ IrArray::Index(index.GetType())));
ir_builder_->CreateStore(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
@@ -1824,10 +1835,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
- std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
- IrName(hlo, "inner"), ir_builder_->getInt64(0),
- ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1),
- ir_builder_);
+ llvm::Type* index_type = dot_result_index[0]->getType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+
+ std::unique_ptr<llvm_ir::ForLoop> inner_loop =
+ llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0),
+ index_typed_const(contracted_dim_size),
+ index_typed_const(1), ir_builder_);
SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_);
PrimitiveType primitive_type = hlo->shape().element_type();
@@ -1846,7 +1862,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
// Given an output index [a,b,c,d,e] in the result, we compute:
// sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
- IrArray::Index lhs_index, rhs_index;
+ IrArray::Index lhs_index(index_type), rhs_index(index_type);
for (int64 i = 0; i < lhs_dims - 1; i++) {
lhs_index.push_back(dot_result_index[i]);