about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-05-16 05:11:33 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-05-16 05:14:37 -0700
commit: 15c319f856b205664b2f462e0434ada22037771a (patch)
tree: 6453ad3cd68a22f3434db84191e0ffbb2a1e124e /tensorflow
parent: b0e4e9f5ccbeee2372c9c8ff516b6c5598376bd1 (diff)
Refactor HloInstruction::Fuse and add a method for multi-output fusion.
PiperOrigin-RevId: 196813042
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.cc 23
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.h 10
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion_test.cc 85
3 files changed, 113 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 06b84cc145..cb6c98c481 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -414,12 +414,9 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
return changed;
}
-HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
- HloInstruction* consumer) {
+HloInstruction* InstructionFusion::AddFusionInstruction(
+ HloInstruction* producer, HloInstruction* consumer) {
HloInstruction* fusion_instruction;
-
- VLOG(2) << "Fusing " << producer->ToString() << " into "
- << consumer->ToString();
auto kind = ChooseKind(producer, consumer);
if (consumer->opcode() == HloOpcode::kFusion) {
fusion_instruction = consumer;
@@ -431,11 +428,27 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
}
+ return fusion_instruction;
+}
+HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
+ HloInstruction* consumer) {
+ VLOG(2) << "Fusing " << producer->ToString() << " into "
+ << consumer->ToString();
+ HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
fusion_instruction->FuseInstruction(producer);
return fusion_instruction;
}
+HloInstruction* InstructionFusion::FuseIntoMultiOutput(
+ HloInstruction* producer, HloInstruction* consumer) {
+ VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
+ << consumer->ToString();
+ HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
+ fusion_instruction->FuseInstructionIntoMultiOutput(producer);
+ return fusion_instruction;
+}
+
bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 2ea1fcf937..c3c2ed0aaa 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -70,6 +70,13 @@ class InstructionFusion : public HloPassInterface {
virtual HloInstruction* Fuse(HloInstruction* producer,
HloInstruction* consumer);
+ // Creates a new fusion instruction containing `producer` and `consumer`. A
+ // tuple is added as the fusion instruction's root, which consumes from both,
+ // `producer` and `consumer`. This style of fusion is referred to as
+ // multi-output fusion.
+ virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
+ HloInstruction* consumer);
+
// An "effectively unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
@@ -95,6 +102,9 @@ class InstructionFusion : public HloPassInterface {
// The set of producers whose consumers we cannot fuse into.
using DoNotFuseSet = std::unordered_set<HloInstruction*>;
+ HloInstruction* AddFusionInstruction(HloInstruction* producer,
+ HloInstruction* consumer);
+
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index cf9673a38a..df109df787 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -25,6 +25,91 @@ namespace op = xla::testing::opcode_matchers;
using InstructionFusionTest = HloTestBase;
+// Subclass of InstructionFusion exposing the protected methods Fuse and
+// FuseIntoMultiOutput for testing.
+class InstructionFusionForTesting : public InstructionFusion {
+ public:
+ explicit InstructionFusionForTesting(HloModule* module)
+ : InstructionFusion(InstructionFusion::IsExpensive) {
+ module_ = module;
+ computation_ = module->entry_computation();
+ }
+
+ HloInstruction* Fuse(HloInstruction* producer,
+ HloInstruction* consumer) override {
+ return InstructionFusion::Fuse(producer, consumer);
+ }
+
+ HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
+ HloInstruction* consumer) override {
+ return InstructionFusion::FuseIntoMultiOutput(producer, consumer);
+ }
+};
+
+TEST_F(InstructionFusionTest, FuseInstructions) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY entry_computation {
+ p0 = f32[4,3]{1,0} parameter(0)
+ add = f32[4,3]{1,0} add(p0, p0)
+ ROOT sub = f32[4,3]{1,0} subtract(add, p0)
+ })")
+ .ValueOrDie();
+ HloInstruction* sub = module->entry_computation()->root_instruction();
+ HloInstruction* add = sub->mutable_operand(0);
+ HloInstruction* fusion =
+ InstructionFusionForTesting(module.get()).Fuse(add, sub);
+
+ ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Subtract(op::Add(), op::Parameter()))
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ fused_computation {
+ p1 = f32[4,3] parameter(0)
+ add = f32[4,3] add(p1, p1)
+ }
+ ENTRY entry_computation {
+ p0 = f32[4,3] parameter(0)
+ abs = f32[4,3] abs(p0)
+ ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation
+ })")
+ .ValueOrDie();
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ HloInstruction* abs = root->mutable_operand(0);
+ HloInstruction* fusion =
+ InstructionFusionForTesting(module.get()).Fuse(abs, root);
+
+ ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
+ EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs()))
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY entry_computation {
+ p0 = f32[4,3]{1,0} parameter(0)
+ abs = f32[4,3]{1,0} abs(p0)
+ tanh = f32[4,3]{1,0} tanh(abs)
+ ROOT add = f32[4,3]{1,0} add(abs, tanh)
+ })")
+ .ValueOrDie();
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ HloInstruction* abs = root->mutable_operand(0);
+ HloInstruction* tanh = root->mutable_operand(1);
+ HloInstruction* fusion =
+ InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh);
+
+ ASSERT_THAT(fusion, op::Fusion()) << module->ToString();
+ EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs()))
+ << module->ToString();
+}
+
TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(