aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/reshape_mover_test.cc
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-05-24 17:24:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-24 17:28:21 -0700
commitbd9c7ddf72ff5fe989fec26db153b91675911b5b (patch)
tree9a9d7fe20a3fc20f417a4b732b42ae18b1941d7f /tensorflow/compiler/xla/service/reshape_mover_test.cc
parent36e942a1c36be0fbc72b31fe0780bad9475c1f6e (diff)
[XLA] In ReshapeMover:
- change OperandCanTriviallyChangeShape to only look at operand's opcode, instead of comparing the dimension of the shape between the instruction and its operand. - remove accounting for non-trivial reshapes/transposes, and instead move the requirement to FirstNonScalarReshapeOperand instead. This fixes an issue where a elementwise op's reshape0(constant0) operand gets moved (because reshape0 and constant0 are of different shapes), even when there is no benefit is doing so. Also fixes incorrect test comments. PiperOrigin-RevId: 157060319
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc54
1 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 895ea107a8..7d8b462279 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -62,6 +62,45 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
op::Add(op::Reshape(param0), op::Reshape(param1)));
}
+// For a graph that looks like:
+//
+// +- reshape0 - rng0
+// |
+// +- const1
+// |
+// add
+//
+// where rng0 has a different shape than reshape0.
+//
+// Verifies that the reshape is not moved, since rng0 is trivially reshapable
+// and therefore there is no nontrivial reshapes to move.
+TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
+ HloComputation::Builder builder(TestName());
+ auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
+ auto rng0 = builder.AddInstruction(
+ HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
+ RandomDistribution::RNG_UNIFORM, {}));
+ auto reshape0 =
+ builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));
+
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape)));
+
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ root_shape, HloOpcode::kAdd, reshape0, const1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(),
+ op::Add(op::Reshape(rng0), const1));
+
+ EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(),
+ op::Add(op::Reshape(rng0), const1));
+}
+
TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {});
@@ -181,13 +220,8 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
// |
// add
//
-// Verifies that the reshape0 *does not* unnecessarily sink below add:
-//
-// +- reshape0 - param0
-// |
-// +- param1
-// |
-// add
+// Verifies that the reshape0 does not sink below add, because param1 is not
+// trivially reshapable nor is a Reshape/Transpose.
TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
@@ -215,12 +249,12 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
// For a graph that looks like:
//
+// +- pred
+// |
// +- reshape0 - const0
// |
// +- reshape1 - const1
// |
-// +- param1
-// |
// select
//
// Verifies that we don't unnecessarily sink reshapes, which are in fact
@@ -278,7 +312,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
// reshape2
//
// (note that reshape1 here is trivial).
-TEST_F(ReshapeMoverTest, 1NonTrivialReshapeNotMoved) {
+TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(