aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-07 00:52:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 00:56:13 -0700
commit424de2b5279bf3779c27a39403f94281f3460543 (patch)
treefc5e822a8fd60bf3aaf7d49b4e329fc01ecd4392 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parentdebd66dae1c9a49d36ea006c97facf06b4ac25cb (diff)
[XLA:GPU] Clean up init thunk handling to handle arbitrary fused init values
I put this in as a quick hack because init_value is usually a constant, but it's really easy to construct a case where it's not. The code also became more complex because of the constant buffer work, sharing that with the fused IR emitter is a good thing. PiperOrigin-RevId: 211936337
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc59
1 files changed, 33 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 0c7623fd79..f91cc00d71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index) {
+ HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
- const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value_operand = [&] {
+ HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+ HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
- return inst->operand(2);
+ return inst->mutable_operand(2);
case HloOpcode::kReduce:
- return inst->operand(1);
+ return inst->mutable_operand(1);
case HloOpcode::kTuple:
CHECK(hlo->IsMultiOutputFusion())
<< ": " << hlo->ToString() << " is not a multi-output fusion.";
@@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
<< ": Found '" << inst->operand(index.back())->opcode() << "' in "
<< inst->ToString() << " but expected 'reduce'.";
// For multi-output fusion look through the tuple.
- return inst->operand(index.back())->operand(1);
+ return inst->mutable_operand(index.back())->mutable_operand(1);
default:
LOG(FATAL) << "Opcode " << inst->opcode()
<< " should not need an initializer.";
@@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
- // If the init_value was fused into this reduce we have to generate it first.
- if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
- CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
- const Literal& literal = init_value_operand->literal();
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ if (fused) {
+ // If init_value was fused into this reduce we have to generate it first.
+ std::vector<IrArray> parameter_arrays;
+ for (HloInstruction* operand : hlo->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand, *hlo));
+ }
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &b_, GetNestedComputer());
- llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
- *module_, initializer->getType(),
- /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
- /*Name=*/"");
- global_for_const->setAlignment(kConstantBufferAlignBytes);
- bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
- }
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &b_);
- },
- GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
- .EmitLoop(IrName(hlo)));
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
+ TF_RETURN_IF_ERROR(
+ ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
+ } else {
+ // In the unfused case the element is already there, just read from it.
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &b_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
+ }
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)