aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-09-02 16:09:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-02 16:13:34 -0700
commitcaf4645b1e73e3e1111fb72342a5ade835ac1bf1 (patch)
tree22ef702eb58b466bd7ee24a4087451a08a732500
parent201be3d514d7239aa19496dba4dd0c85303b03f1 (diff)
[XLA] Simplify effective scalar iota to zero
Happened to observe this come up in a linear algebra workload. PiperOrigin-RevId: 211290278
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc15
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc33
2 files changed, 42 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 95e554c9a5..7c078f07d7 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -127,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleImag(HloInstruction* imag) override;
+ Status HandleIota(HloInstruction* instruction) override;
+
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleDivide(HloInstruction* divide) override;
@@ -1462,6 +1464,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
+ // iota -> zero if the iota dimension never produces an element other than
+ // zero.
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
+ auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ return ReplaceWithNewInstruction(
+ iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
+ }
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index b4ff048db0..43a891e4fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1858,12 +1858,33 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
}
-TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) {
+TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
HloComputation::Builder builder(TestName());
auto iota = builder.AddInstruction(
- HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1));
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
+ auto result_shape = iota->shape();
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ auto root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Constant()));
+ EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
builder.AddInstruction(
- HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota));
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota));
auto computation = module().AddEntryComputation(builder.Build());
@@ -1897,12 +1918,12 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
3);
}
-TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) {
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
HloComputation::Builder builder(TestName());
auto iota = builder.AddInstruction(
- HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2));
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota));
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));
HloComputation* computation = module().AddEntryComputation(builder.Build());