diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-07 00:52:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 00:56:13 -0700 |
commit | 424de2b5279bf3779c27a39403f94281f3460543 (patch) | |
tree | fc5e822a8fd60bf3aaf7d49b4e329fc01ecd4392 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
parent | debd66dae1c9a49d36ea006c97facf06b4ac25cb (diff) |
[XLA:GPU] Clean up init thunk handling to handle arbitrary fused init values
I put this in as a quick hack because init_value is usually a constant, but
it's really easy to construct a case where it's not. The code also became more
complex because of the constant buffer work; sharing that with the fused IR
emitter is a good thing.
PiperOrigin-RevId: 211936337
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 59 |
1 files changed, 33 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 0c7623fd79..f91cc00d71 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( } StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. 
- if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. + std::vector<IrArray> parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); - } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) |