diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-27 11:23:32 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-27 12:06:51 -0800 |
commit | 332ee5051fc38babb53cc8cbf3c3120e5651f4e8 (patch) | |
tree | 91be33183aac60063e973459757b4002b52c6ecc /tensorflow/compiler/xla/service/reshape_mover_test.cc | |
parent | b4d091d5a372f97af48192cb431985b20b447158 (diff) |
[XLA] Set layout of fused instructions in `ReshapeMover`
Change: 148671742
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover_test.cc | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 850295c726..5028300ecf 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -53,5 +53,30 @@ TEST_F(ReshapeMoverTest, ReshapesWithNonSameInputShapesNotMoved) { EXPECT_EQ(add4, computation->root_instruction()); } +TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape3 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape2, reshape3)); + + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + auto fusion = computation->AddInstruction(HloInstruction::CreateFusion( + add4->shape(), HloInstruction::FusionKind::kLoop, add4)); + TF_CHECK_OK(computation->ReplaceInstruction(add4, fusion)); + EXPECT_EQ(fusion, computation->root_instruction()); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_NE(fusion, computation->root_instruction()); + EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode()); +} + } // namespace } // namespace xla |