path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
author Justin Lebar <jlebar@google.com> 2018-07-11 09:40:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-11 09:44:11 -0700
commit d71be4d5febada6af32f3286ad2f4ec61cefb1b3 (patch)
tree 8c71b02ef23b3134919b542df3a18b1f30546fbf /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent 5ff365a4cebce6f50f4cfeeab7490992f6089961 (diff)
[XLA:GPU] s/llvm_ir::IrArray/IrArray/ in ir_emitter_unnested.
Less visual noise.

PiperOrigin-RevId: 204139183
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 112
1 file changed, 50 insertions(+), 62 deletions(-)
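
Every hunk below simply drops the llvm_ir:: qualifier from IrArray. For that to compile, the name has to be brought into the gpu namespace earlier in the file, presumably with a using-declaration that this excerpt does not show. A minimal self-contained sketch of the pattern, with stand-in namespaces and a hypothetical MakeParameterArrays helper (not from the commit):

#include <vector>

// Stand-in namespace and type for illustration; the real ones live in
// xla::llvm_ir and xla::gpu.
namespace llvm_ir {
class IrArray {};
}  // namespace llvm_ir

namespace gpu {
// One using-declaration near the top of the .cc file...
using llvm_ir::IrArray;

// ...lets every subsequent use drop the llvm_ir:: qualifier, which is
// exactly the rewrite the hunks below apply.
std::vector<IrArray> MakeParameterArrays() {
  std::vector<IrArray> parameter_arrays;  // was std::vector<llvm_ir::IrArray>
  parameter_arrays.push_back(IrArray());
  return parameter_arrays;
}
}  // namespace gpu
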
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a97447860e..673ba530df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -595,7 +595,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
- std::vector<llvm_ir::IrArray> parameter_arrays;
+ std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -668,7 +668,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// Set up kernel thunk and fused ir emitter.
thunk_sequence_->emplace_back(
BuildKernelThunk(fusion, /*implements_whole_instruction=*/true));
- std::vector<llvm_ir::IrArray> operand_arrays;
+ std::vector<IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -681,7 +681,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// Array to write into. Because this is an in-place operation, this is the
// same as operand 0's array.
- llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
+ IrArray output_array = GetIrArray(*fusion, *fusion);
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
update_shape, ir_emitter_context_->device_description());
@@ -732,7 +732,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
- const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ const HloInstruction* reduce, const IrArray::Index& index,
tensorflow::gtl::ArraySlice<
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
@@ -819,8 +819,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // and threads_per_block is a multiple of warpSize.
// reduce_kernel<<<num_blocks, threads_per_block>>>();
//
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
@@ -829,9 +828,8 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -866,7 +864,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
}
- llvm_ir::IrArray::Index input_index(
+ IrArray::Index input_index(
/*linear=*/x, input_shape, &ir_builder_);
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
@@ -951,7 +949,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
+ IrArray::Index(
/*linear=*/ir_builder_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
@@ -1037,8 +1035,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
// }
// AtomicReducer(&output[x], partial_result);
// }
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type =
@@ -1048,9 +1045,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1106,9 +1102,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
const Shape input_matrix_shape =
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
- const llvm_ir::IrArray::Index input_matrix_index(
- {y, x}, input_matrix_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
+ &ir_builder_);
+ const IrArray::Index input_index =
input_matrix_index
.SourceIndexOfReshape(input_matrix_shape,
normalized_input_shape, &ir_builder_)
@@ -1159,11 +1155,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- x,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
+ IrArray::Index(x,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
@@ -1335,7 +1330,7 @@ Status IrEmitterUnnested::EmitRowReduction(
return llvm::ConstantInt::get(index_ty, c);
};
- auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
@@ -1344,9 +1339,8 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1435,9 +1429,9 @@ Status IrEmitterUnnested::EmitRowReduction(
const Shape input_3d_tensor_shape =
ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {depth, height, width});
- const llvm_ir::IrArray::Index input_3d_tensor_index(
+ const IrArray::Index input_3d_tensor_index(
{z, y, x}, input_3d_tensor_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_index =
input_3d_tensor_index
.SourceIndexOfReshape(input_3d_tensor_shape,
normalized_input_shape,
@@ -1532,11 +1526,10 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- y,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
+ IrArray::Index(y,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
// We don't need to emit atomic operations if there is only one tile of
// results. 'depth' is the z dimension, 'width' is the x dimension.
@@ -1686,11 +1679,11 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
- reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
+ reduce, input->shape(), {[&](const IrArray::Index& index) {
return GetIrArray(*input, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
- {[&](const llvm_ir::IrArray::Index& index) {
+ {[&](const IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
@@ -1791,8 +1784,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// selected_index = I
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& source_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
// Allocate space to keep the currently selected value, its index, and a
// boolean flag if the value is initialized. The initialized_flag is set
// false.
@@ -1817,7 +1809,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
window_size.push_back(dim.size());
CHECK_GT(dim.size(), 0);
}
- const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ const IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
&ir_builder_);
@@ -1825,7 +1817,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(index_type, source_index.size());
+ IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
@@ -1853,8 +1845,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
- const auto save_operand_index = [&](
- const llvm_ir::IrArray::Index& operand_index) {
+ const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
ir_builder_.CreateInBoundsGEP(selected_index_address,
@@ -1862,7 +1853,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
- llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
+ IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
ir_builder_.CreateStore(operand_data, selected_value_address);
@@ -1907,7 +1898,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// value and the current output value.
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
- llvm_ir::IrArray::Index selected_index(operand_index.GetType());
+ IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
@@ -2492,7 +2483,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const llvm_ir::IrArray::Index& index) {
+ [=](const IrArray::Index& index) {
return GetIrArray(*init_value, *hlo)
.EmitReadArrayElement(index, &ir_builder_);
},
@@ -2688,7 +2679,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
// For multioutput fusion, we need to emit each operand and the root.
- std::vector<llvm_ir::IrArray> output_arrays;
+ std::vector<IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
@@ -2718,7 +2709,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
}
int IrEmitterUnnested::ConstructIrArrayForOutputs(
- const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* output_arrays) {
+ const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
int64 num_outputs = 1;
if (hlo.IsMultiOutputFusion()) {
num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
@@ -2733,7 +2724,7 @@ int IrEmitterUnnested::ConstructIrArrayForOutputs(
}
int IrEmitterUnnested::ConstructIrArrayForInputs(
- const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* param_arrays) {
+ const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
int64 num_params = hlo.operands().size();
param_arrays->reserve(num_params);
for (const HloInstruction* param : hlo.operands()) {
@@ -2743,11 +2734,10 @@ int IrEmitterUnnested::ConstructIrArrayForInputs(
}
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
- const HloInstruction& hlo,
- const std::vector<llvm_ir::IrArray>& output_arrays,
+ const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
- std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays) {
+ std::vector<IrArray>* output_in_reduced_shape_arrays) {
int64 num_outputs = 1;
if (hlo.IsMultiOutputFusion()) {
num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
@@ -2770,19 +2760,18 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
}
int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
- const HloInstruction& hlo,
- const std::vector<llvm_ir::IrArray>& param_arrays,
+ const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
- std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays) {
+ std::vector<IrArray>* param_in_reduced_shape_arrays) {
int64 num_params = hlo.operands().size();
param_in_reduced_shape_arrays->reserve(num_params);
param_reduced_shapes->reserve(num_params);
for (int64 id = 0; id < num_params; ++id) {
if (param_buffers[id] == nullptr) {
param_reduced_shapes->push_back(Shape());
- param_in_reduced_shape_arrays->push_back(llvm_ir::IrArray());
+ param_in_reduced_shape_arrays->push_back(IrArray());
continue;
}
const HloInstruction* param = hlo.operand(id);
@@ -2835,11 +2824,11 @@ llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
// processed element is within the boundary defined by `tile_width` and
// `tile_height`.
void EmitTiledElementalCodeWithBoundsCheck(
- int64 tile_size, int64 num_rows, const llvm_ir::IrArray::Index& index,
+ int64 tile_size, int64 num_rows, const IrArray::Index& index,
const string& loop_name, KernelSupportLibrary* ksl,
llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x,
llvm::Value* tile_width, llvm::Value* tile_height,
- const std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>&
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
emit_elem_function) {
llvm::Type* index_ty = tile_width->getType();
// Emits a constant value with index type.
@@ -2847,8 +2836,7 @@ void EmitTiledElementalCodeWithBoundsCheck(
return llvm::ConstantInt::get(index_ty, c);
};
// Adds `addend` to the given `dim` of `index`.
- auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
- int64 dim) {
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
index[dim] = builder->CreateAdd(index[dim], addend);
return index;
};
@@ -3037,8 +3025,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
auto emit_tiled_elemental_code_with_bounds_check =
[&](const IrArray::Index& index, const string& loop_name,
llvm::Value* tile_width, llvm::Value* tile_height,
- const std::function<void(const llvm_ir::IrArray::Index&,
- llvm::Value*)>& emit_elem_function) {
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
EmitTiledElementalCodeWithBoundsCheck(
kTileSize, kNumRows, index, loop_name, &ksl, &ir_builder_, y, x,
tile_width, tile_height, emit_elem_function);