aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 23
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 68
-rw-r--r-- tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h | 12
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.h | 7
7 files changed, 128 insertions, 32 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 95f8165795..1a18b28cbb 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,15 +180,18 @@ cc_library(
cc_library(
name = "ir_emitter",
- srcs = ["ir_emitter.cc"],
+ srcs = [
+ "elemental_ir_emitter.cc",
+ "ir_emitter.cc",
+ ],
hdrs = [
+ "elemental_ir_emitter.h",
"ir_emitter.h",
],
deps = [
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
- ":elemental_ir_emitter",
":ir_emission_utils",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
@@ -526,22 +529,6 @@ cc_library(
)
cc_library(
- name = "elemental_ir_emitter",
- srcs = ["elemental_ir_emitter.cc"],
- hdrs = ["elemental_ir_emitter.h"],
- deps = [
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:elemental_ir_emitter",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "@llvm//:core",
- ],
-)
-
-cc_library(
name = "ir_emission_utils",
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 511f89144a..902309b338 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
- // Producer or consumer cannot be Map. Maps are technically elementwise but
- // of a slightly different form (call instead of a computation). These are not
- // yet supported in the CPU backend.
- if (producer->opcode() == HloOpcode::kMap ||
- consumer->opcode() == HloOpcode::kMap) {
- return false;
- }
-
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion &&
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fc62281a0..b56466d5e4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest {
std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
expected_opcodes);
}
+
+ // Builds a scalar F32 computation with one parameter ("arg0") whose root is
+ // `arg0 + 1.0`, registers it with `module` as an embedded computation, and
+ // returns the registered computation.
+ HloComputation* CreateAdderToOne(HloModule* module) {
+   const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
+   HloComputation::Builder adder(TestName());
+   HloInstruction* input = adder.AddInstruction(
+       HloInstruction::CreateParameter(0, scalar_f32, "arg0"));
+   HloInstruction* unit = adder.AddInstruction(
+       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+   // The last instruction added is the binary add of the parameter and 1.0.
+   adder.AddInstruction(HloInstruction::CreateBinary(
+       scalar_f32, HloOpcode::kAdd, input, unit));
+   return module->AddEmbeddedComputation(adder.Build());
+ }
+
+ // Builds a scalar F32 computation with two parameters ("arg0", "arg1")
+ // combined by kMaximum, registers it with `module` as an embedded
+ // computation, and returns the registered computation.
+ HloComputation* CreateMax(HloModule* module) {
+   const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
+   HloComputation::Builder maximum(TestName());
+   HloInstruction* lhs = maximum.AddInstruction(
+       HloInstruction::CreateParameter(0, scalar_f32, "arg0"));
+   HloInstruction* rhs = maximum.AddInstruction(
+       HloInstruction::CreateParameter(1, scalar_f32, "arg1"));
+   maximum.AddInstruction(HloInstruction::CreateBinary(
+       scalar_f32, HloOpcode::kMaximum, lhs, rhs));
+   return module->AddEmbeddedComputation(maximum.Build());
+ }
};
TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
@@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
HloOpcode::kParameter});
}
+// Checks that a unary kMap whose single operand is an kExp of a parameter is
+// fused together with that kExp.
+TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
+  auto module = CreateNewModule();
+  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+
+  HloComputation::Builder entry(TestName());
+  HloInstruction* input = entry.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  HloInstruction* exponential = entry.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
+
+  // The mapped computation adds 1.0 to each element.
+  HloComputation* add_one = CreateAdderToOne(module.get());
+  entry.AddInstruction(HloInstruction::CreateMap(
+      shape, {exponential}, add_one, /*static_operands=*/{}));
+
+  module->AddEntryComputation(entry.Build());
+
+  RunFusionAndCheckOpcodesWereFused(
+      module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
+}
+
+// Checks that a binary kMap whose two operands are each an kExp of a distinct
+// parameter is fused together with both kExp producers.
+TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
+  auto module = CreateNewModule();
+  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+
+  HloComputation::Builder entry(TestName());
+  HloInstruction* lhs = entry.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  HloInstruction* rhs = entry.AddInstruction(
+      HloInstruction::CreateParameter(1, shape, "param"));
+
+  HloInstruction* exp_lhs = entry.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, lhs));
+  HloInstruction* exp_rhs = entry.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, rhs));
+
+  // The mapped computation takes the elementwise maximum of its two inputs.
+  HloComputation* maximum = CreateMax(module.get());
+  entry.AddInstruction(HloInstruction::CreateMap(
+      shape, {exp_lhs, exp_rhs}, maximum, /*static_operands=*/{}));
+
+  module->AddEntryComputation(entry.Build());
+
+  RunFusionAndCheckOpcodesWereFused(
+      module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
+                     HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
+}
} // namespace
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index fe447adf89..73e039250b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
+// Returns a generator that produces one element of `hlo`'s result for a given
+// output index.
+//
+// kMap is handled here: the generator evaluates every operand's generator,
+// then emits a scalar call to the mapped computation (`hlo->to_apply()`)
+// through the owning IrEmitter. All other opcodes are delegated to the base
+// ElementalIrEmitter.
+//
+// `operand_to_generator` is captured by reference by the returned generator,
+// so the map must stay alive for as long as the generator may be invoked. It
+// must contain an entry for every operand of `hlo` (`at()` is used for the
+// lookup).
+llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
+    const HloInstruction* hlo,
+    const HloToElementGeneratorMap& operand_to_generator) const {
+  if (hlo->opcode() == HloOpcode::kMap) {
+    return [this, hlo, &operand_to_generator](
+               const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
+      std::vector<llvm::Value*> operands;
+      operands.reserve(hlo->operand_count());
+      for (int i = 0; i < hlo->operand_count(); i++) {
+        // The source index is derived per operand: pass `i`, not a fixed 0,
+        // so each operand is indexed according to its own position rather
+        // than always according to operand 0.
+        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+                            operand_to_generator.at(hlo->operand(i))(
+                                ElementwiseSourceIndex(index, *hlo, i)));
+        operands.push_back(operand_value);
+      }
+      return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
+                                         hlo->to_apply(), operands,
+                                         llvm_ir::IrName(hlo));
+    };
+  }
+  return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
+}
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 6f9d6a24b4..7e9f27befb 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,12 +30,19 @@ namespace cpu {
+// CPU-specific ElementalIrEmitter. It is constructed from the IrEmitter that
+// uses it: the IRBuilder handed to the base class is taken from that emitter,
+// and the emitter pointer itself is retained in `ir_emitter_`.
class CpuElementalIrEmitter : public ElementalIrEmitter {
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
- llvm::IRBuilder<>* ir_builder, llvm::Module* module)
- : ElementalIrEmitter(module_config, module, ir_builder) {}
+ IrEmitter* ir_emitter, llvm::Module* module)
+ : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()),
+ ir_emitter_(ir_emitter) {}
+
+ // Overrides the base element-generator factory for `hlo`, given the
+ // generators of its operands.
+ llvm_ir::ElementGenerator MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
+
+ // Non-owning pointer to the IrEmitter passed to the constructor.
+ // NOTE(review): presumably must outlive this object -- confirm at call sites.
+ IrEmitter* ir_emitter_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 7d82f33152..8cd8740ee8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArrayForOp(operand));
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@@ -3176,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
};
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
return EmitTargetElementLoop(
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
+// Emits a call to the LLVM function previously recorded for `computation` in
+// `emitted_functions_`, passing the scalar values in `arguments`.
+//
+// Each argument value is first stored into its own alloca (created at the
+// function entry and named "arg_addr"); the addresses of those slots are what
+// is handed to the element function call. The call's result shape is a scalar
+// of `return_type`, and `name` is forwarded to the emitted call.
+StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
+    PrimitiveType return_type, HloComputation* computation,
+    const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
+  llvm::Function* callee = FindOrDie(emitted_functions_, computation);
+
+  std::vector<llvm::Value*> argument_slots;
+  for (llvm::Value* scalar : arguments) {
+    llvm::Value* slot = llvm_ir::EmitAllocaAtFunctionEntry(
+        scalar->getType(), "arg_addr", &ir_builder_);
+    ir_builder_.CreateStore(scalar, slot);
+    argument_slots.push_back(slot);
+  }
+
+  const auto scalar_result_shape = ShapeUtil::MakeShape(return_type, {});
+  return EmitElementFunctionCall(callee, scalar_result_shape, argument_slots,
+                                 name);
+}
+
unsigned TargetMachineFeatures::largest_register_size_in_bytes(
llvm::Function* function) {
auto itr = largest_register_size_in_bytes_.find(function);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index bcd33c3810..fa33a1eb7b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation,
std::vector<const HloInstruction*>* instruction_order);
+ // Returns the address of this emitter's `ir_builder_` member; the emitter
+ // keeps ownership of the builder.
+ llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
+
+ // Emits a call to `computation` with scalar arguments `arguments`.
+ //
+ // `return_type` is the element type of the scalar result and `name` is
+ // forwarded to the emitted call. On success the returned StatusOr holds the
+ // llvm::Value produced for the call.
+ // NOTE(review): `computation` presumably must already have been emitted by
+ // this IrEmitter -- confirm against the definition.
+ StatusOr<llvm::Value*> EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+
protected:
//
// The following methods implement the DfsHloVisitor interface.