diff options
author | David Majnemer <majnemer@google.com> | 2018-09-02 16:09:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-02 16:13:34 -0700 |
commit | caf4645b1e73e3e1111fb72342a5ade835ac1bf1 (patch) | |
tree | 22ef702eb58b466bd7ee24a4087451a08a732500 | |
parent | 201be3d514d7239aa19496dba4dd0c85303b03f1 (diff) |
[XLA] Simplify effective scalar iota to zero
Happened to observe this come up in a linear algebra workload.
PiperOrigin-RevId: 211290278
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 33 |
2 files changed, 42 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 95e554c9a5..7c078f07d7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -127,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleIota(HloInstruction* instruction) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleDivide(HloInstruction* divide) override; @@ -1462,6 +1464,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { + // iota -> zero if the iota dimension never produces an element other than + // zero. + auto* iota = Cast<HloIotaInstruction>(instruction); + if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + return ReplaceWithNewInstruction( + iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index b4ff048db0..43a891e4fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1858,12 +1858,33 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } -TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) { +TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( - HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1)); + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); + auto result_shape = iota->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + auto root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota)); + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); auto computation = module().AddEntryComputation(builder.Build()); @@ -1897,12 +1918,12 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { 3); } -TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) { +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( - HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2)); + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota)); + ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); HloComputation* computation = module().AddEntryComputation(builder.Build()); |