aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc15
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc20
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc51
3 files changed, 68 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4c058484b9..415aafe69a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1348,13 +1348,14 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
- bool changed =
- std::any_of(module->computations().begin(), module->computations().end(),
- [=](const std::unique_ptr<HloComputation>& computation) {
- return AlgebraicSimplifierVisitor::Run(
- computation.get(), is_layout_sensitive_,
- valid_bitcast_callback_, enable_dot_simplification_);
- });
+ bool changed = false;
+ for (auto& comp : module->computations()) {
+ if (AlgebraicSimplifierVisitor::Run(comp.get(), is_layout_sensitive_,
+ valid_bitcast_callback_,
+ enable_dot_simplification_)) {
+ changed = true;
+ }
+ }
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), after:\n" + module->ToString());
return changed;
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index 3bff35544c..b72ef95a6a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -234,17 +234,15 @@ bool TrySinkReshapeOrTranspose(HloComputation* computation,
} // namespace
StatusOr<bool> ReshapeMover::Run(HloModule* module) {
- return std::any_of(
- module->computations().begin(), module->computations().end(),
- [](const std::unique_ptr<HloComputation>& computation) {
- std::list<HloInstruction*> postorder =
- computation->MakeInstructionPostOrder();
- return std::any_of(postorder.begin(), postorder.end(),
- [&computation](HloInstruction* instruction) {
- return TrySinkReshapeOrTranspose(computation.get(),
- instruction);
- });
- });
+ bool changed = false;
+ for (const auto& comp : module->computations()) {
+ for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
+ if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
+ changed = true;
+ }
+ }
+ }
+ return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 1862e2e992..09a673ea80 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -202,5 +202,56 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
EXPECT_EQ(select, computation->root_instruction());
}
+// Tree looks like this:
+//
+// add1
+// |
+// +- reshape2 - param2
+// |
+// +- reshape3 - add0
+// |
+// + reshape0 - param0
+// |
+// + reshape1 - param1
+//
+// We expect reshape{0,1} AND reshape{2,3} to be lifted.
+TEST_F(ReshapeMoverTest, MultiplePasses) {
+ auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
+ auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
+ auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape1, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape1, "param1"));
+ auto param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, shape2, "param2"));
+ auto reshape0 =
+ builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
+ auto reshape1 =
+ builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape2, HloOpcode::kAdd, reshape0, reshape1));
+ auto reshape2 =
+ builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
+ auto reshape3 =
+ builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape3, HloOpcode::kAdd, reshape2, reshape3));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(add1, computation->root_instruction());
+ EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kAdd,
+ computation->root_instruction()->operand(0)->opcode());
+ const auto& add_params =
+ computation->root_instruction()->operand(0)->operands();
+ EXPECT_EQ(2, add_params.size());
+ EXPECT_EQ(HloOpcode::kParameter, add_params[0]->opcode());
+ EXPECT_EQ(HloOpcode::kReshape, add_params[1]->opcode());
+}
+
} // namespace
} // namespace xla