diff options
author | David Majnemer <majnemer@google.com> | 2018-08-28 17:08:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 17:12:30 -0700 |
commit | de25baba51a7369bbdc5fd8051af0e9deb6a268d (patch) | |
tree | 4dd13806044b9e80128164ecaaf0d13ec2548173 | |
parent | 4dcd00066fba2bd7c504c1bc35738f804de9df67 (diff) |
[XLA] Add support for algebraic simplifications involving kIota
PiperOrigin-RevId: 210634966
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 34 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 142 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers.h | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher.h | 1 |
5 files changed, 175 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b68785949c..4aef093b04 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1350,6 +1350,7 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_creation_utils", ":hlo_pass", ":hlo_query", @@ -1376,6 +1377,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c236453fc7..19bb4da9a6 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -30,9 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -1238,7 +1240,7 @@ namespace { // return value = {1, 3} // // Precondition: input_dim_indices is sorted. -std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified( +absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified( const HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); @@ -1258,11 +1260,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified( } if (i >= unmodified_dims.size() || unmodified_dims[i].first != input_dim_index) { - return std::make_pair(false, std::vector<int64>()); + return absl::nullopt; } output_dim_indices.push_back(unmodified_dims[i].second); } - return std::make_pair(true, output_dim_indices); + return output_dim_indices; } // Returns true if the output of "instruction" is a permutation of the @@ -1391,6 +1393,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } + // broadcast(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateIota( + broadcast->shape(), + dims[Cast<HloIotaInstruction>(operand)->iota_dimension()])); + } + // Merge two consecutive broadcasts into a single one. if (operand->opcode() == HloOpcode::kBroadcast) { std::vector<int64> new_dimensions; @@ -1719,12 +1730,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); - if (opt_dims.first) { + if (opt_dims.has_value()) { return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), - opt_dims.second)); + *opt_dims)); + } + } + + // reshape(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + auto* iota = Cast<HloIotaInstruction>(operand); + auto opt_dims = + ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); + if (opt_dims.has_value()) { + CHECK_EQ(opt_dims->size(), 1); + return ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index bb63ea26d4..1900a05750 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -23,8 +23,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -1826,6 +1828,105 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { op::Reshape(op::Broadcast(param))); } +TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); + Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction()) + ->iota_dimension(), + 3); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + const int64 iota_dim = + Cast<HloIotaInstruction>(computation->root_instruction()) + ->iota_dimension(); + EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2653,6 +2754,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1)); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2); +} + +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + HloComputation::Builder builder(TestName()); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1)); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector<int64> input_spatials; std::vector<int64> symmetric_pad_spatials; diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 9ace0d76e0..5502e565b6 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -188,6 +188,7 @@ HLO_MATCHER(Fusion); HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); +HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ccc06ce613..4869db79e7 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) { } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Iota) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. |