diff options
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover_test.cc | 51 |
3 files changed, 68 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4c058484b9..415aafe69a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1348,13 +1348,14 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); - bool changed = - std::any_of(module->computations().begin(), module->computations().end(), - [=](const std::unique_ptr<HloComputation>& computation) { - return AlgebraicSimplifierVisitor::Run( - computation.get(), is_layout_sensitive_, - valid_bitcast_callback_, enable_dot_simplification_); - }); + bool changed = false; + for (auto& comp : module->computations()) { + if (AlgebraicSimplifierVisitor::Run(comp.get(), is_layout_sensitive_, + valid_bitcast_callback_, + enable_dot_simplification_)) { + changed = true; + } + } XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); return changed; diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 3bff35544c..b72ef95a6a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -234,17 +234,15 @@ bool TrySinkReshapeOrTranspose(HloComputation* computation, } // namespace StatusOr<bool> ReshapeMover::Run(HloModule* module) { - return std::any_of( - module->computations().begin(), module->computations().end(), - [](const std::unique_ptr<HloComputation>& computation) { - std::list<HloInstruction*> postorder = - computation->MakeInstructionPostOrder(); - return std::any_of(postorder.begin(), postorder.end(), - [&computation](HloInstruction* instruction) { - return TrySinkReshapeOrTranspose(computation.get(), - instruction); - }); - }); + bool changed = false; + for (const auto& comp : module->computations()) { + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + if (TrySinkReshapeOrTranspose(comp.get(), instruction)) { + changed = true; + } + } + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 1862e2e992..09a673ea80 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -202,5 +202,56 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { EXPECT_EQ(select, computation->root_instruction()); } +// Tree looks like this: +// +// add1 +// | +// +- reshape2 - param2 +// | +// +- reshape3 - add0 +// | +// + reshape0 - param0 +// | +// + reshape1 - param1 +// +// We expect reshape{0,1} AND reshape{2,3} to be lifted. +TEST_F(ReshapeMoverTest, MultiplePasses) { + auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); + auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); + auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape1, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape1, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, shape2, "param2")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1)); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + shape2, HloOpcode::kAdd, reshape0, reshape1)); + auto reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2)); + auto reshape3 = + builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0)); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + shape3, HloOpcode::kAdd, reshape2, reshape3)); + + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(add1, computation->root_instruction()); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode()); + EXPECT_EQ(HloOpcode::kAdd, + computation->root_instruction()->operand(0)->opcode()); + const auto& add_params = + computation->root_instruction()->operand(0)->operands(); + EXPECT_EQ(2, add_params.size()); + EXPECT_EQ(HloOpcode::kParameter, add_params[0]->opcode()); + EXPECT_EQ(HloOpcode::kReshape, add_params[1]->opcode()); +} + } // namespace } // namespace xla |