diff options
author | Kay Zhu <kayzhu@google.com> | 2017-05-24 17:24:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-24 17:28:21 -0700 |
commit | bd9c7ddf72ff5fe989fec26db153b91675911b5b (patch) | |
tree | 9a9d7fe20a3fc20f417a4b732b42ae18b1941d7f /tensorflow/compiler/xla/service/reshape_mover_test.cc | |
parent | 36e942a1c36be0fbc72b31fe0780bad9475c1f6e (diff) |
[XLA] In ReshapeMover:
- change OperandCanTriviallyChangeShape to only look at operand's opcode, instead of comparing the dimension of the shape between the instruction and its operand.
- remove accounting for non-trivial reshapes/transposes, and instead move the requirement to FirstNonScalarReshapeOperand instead.
This fixes an issue where a elementwise op's reshape0(constant0) operand gets moved (because reshape0 and constant0 are of different shapes), even when there is no benefit is doing so.
Also fixes incorrect test comments.
PiperOrigin-RevId: 157060319
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover_test.cc | 54 |
1 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 895ea107a8..7d8b462279 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -62,6 +62,45 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { op::Add(op::Reshape(param0), op::Reshape(param1))); } +// For a graph that looks like: +// +// +- reshape0 - rng0 +// | +// +- const1 +// | +// add +// +// where rng0 has a different shape than reshape0. +// +// Verifies that the reshape is not moved, since rng0 is trivially reshapable +// and therefore there is no nontrivial reshapes to move. +TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), + RandomDistribution::RNG_UNIFORM, {})); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); + + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape))); + + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, const1)); + + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(rng0), const1)); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(rng0), const1)); +} + TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); @@ -181,13 +220,8 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { // | // add // -// Verifies that the reshape0 *does not* unnecessarily sink below add: -// -// +- reshape0 - param0 -// | -// +- param1 -// | -// add +// Verifies that the reshape0 does not sink below add, because param1 is not +// trivially reshapable nor is a Reshape/Transpose. TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); @@ -215,12 +249,12 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // For a graph that looks like: // +// +- pred +// | // +- reshape0 - const0 // | // +- reshape1 - const1 // | -// +- param1 -// | // select // // Verifies that we don't unnecessarily sink reshapes, which are in fact @@ -278,7 +312,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { // reshape2 // // (note that reshape1 here is trivial). -TEST_F(ReshapeMoverTest, 1NonTrivialReshapeNotMoved) { +TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( |