diff options
7 files changed, 128 insertions, 32 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 95f8165795..1a18b28cbb 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -180,15 +180,18 @@ cc_library( cc_library( name = "ir_emitter", - srcs = ["ir_emitter.cc"], + srcs = [ + "elemental_ir_emitter.cc", + "ir_emitter.cc", + ], hdrs = [ + "elemental_ir_emitter.h", "ir_emitter.h", ], deps = [ ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":elemental_ir_emitter", ":ir_emission_utils", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -526,22 +529,6 @@ cc_library( ) cc_library( - name = "elemental_ir_emitter", - srcs = ["elemental_ir_emitter.cc"], - hdrs = ["elemental_ir_emitter.h"], - deps = [ - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:elemental_ir_emitter", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm//:core", - ], -) - -cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 511f89144a..902309b338 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // Producer or consumer cannot be Map. Maps are technically elementwise but - // of a slightly different form (call instead of a computation). These are not - // yet supported in the CPU backend. - if (producer->opcode() == HloOpcode::kMap || - consumer->opcode() == HloOpcode::kMap) { - return false; - } - // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 0fc62281a0..b56466d5e4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest { std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()), expected_opcodes); } + + HloComputation* CreateAdderToOne(HloModule* module) { + HloComputation::Builder builder(TestName()); + HloInstruction* arg0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "arg0")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); + return module->AddEmbeddedComputation(builder.Build()); + } + + HloComputation* CreateMax(HloModule* module) { + HloComputation::Builder builder(TestName()); + HloInstruction* arg0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "arg0")); + HloInstruction* arg1 = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "arg1")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1)); + return module->AddEmbeddedComputation(builder.Build()); + } }; TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { @@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { HloOpcode::kParameter}); } +TEST_F(OpcodeFusionTest, UnaryMapOfExp) { + auto module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + + HloInstruction* exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); + builder.AddInstruction(HloInstruction::CreateMap( + shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + + module->AddEntryComputation(builder.Build()); + + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap}); +} + +TEST_F(OpcodeFusionTest, BinaryMapOfExps) { + auto module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param")); + + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); + + builder.AddInstruction(HloInstruction::CreateMap( + shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + + module->AddEntryComputation(builder.Build()); + + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap}); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index fe447adf89..73e039250b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp( } } +llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const { + if (hlo->opcode() == HloOpcode::kMap) { + return [this, hlo, &operand_to_generator]( + const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> { + std::vector<llvm::Value*> operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))( + ElementwiseSourceIndex(index, *hlo, 0))); + operands.push_back(operand_value); + } + return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), + hlo->to_apply(), operands, + llvm_ir::IrName(hlo)); + }; + } + return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 6f9d6a24b4..7e9f27befb 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,12 +30,19 @@ namespace cpu { class CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) - : ElementalIrEmitter(module_config, module, ir_builder) {} + IrEmitter* ir_emitter, llvm::Module* module) + : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()), + ir_emitter_(ir_emitter) {} + + llvm_ir::ElementGenerator MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const override; protected: StatusOr<llvm::Value*> EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + + IrEmitter* ir_emitter_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 7d82f33152..8cd8740ee8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArrayForOp(operand)); } - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, - module_); + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -3176,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); }; } - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, - module_); + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); return EmitTargetElementLoop( hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } +StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { + llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> argument_addrs; + for (auto argument : arguments) { + llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( + argument->getType(), "arg_addr", &ir_builder_); + ir_builder_.CreateStore(argument, argument_addr); + argument_addrs.push_back(argument_addr); + } + return EmitElementFunctionCall(llvm_function, + ShapeUtil::MakeShape(return_type, {}), + argument_addrs, name); +} + unsigned TargetMachineFeatures::largest_register_size_in_bytes( llvm::Function* function) { auto itr = largest_register_size_in_bytes_.find(function); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index bcd33c3810..fa33a1eb7b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation, std::vector<const HloInstruction*>* instruction_order); + llvm::IRBuilder<>* ir_builder() { return &ir_builder_; } + + // Emits a call to `computation` with scalar arguments `arguments`. + StatusOr<llvm::Value*> EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name); + protected: // // The following methods implement the DfsHloVisitor interface. |