aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-28 17:08:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 17:12:30 -0700
commitde25baba51a7369bbdc5fd8051af0e9deb6a268d (patch)
tree4dd13806044b9e80128164ecaaf0d13ec2548173
parent4dcd00066fba2bd7c504c1bc35738f804de9df67 (diff)
[XLA] Add support for algebraic simplifications involving kIota
PiperOrigin-RevId: 210634966
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc34
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc142
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h1
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h1
5 files changed, 175 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b68785949c..4aef093b04 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1350,6 +1350,7 @@ cc_library(
hdrs = ["algebraic_simplifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_creation_utils",
":hlo_pass",
":hlo_query",
@@ -1376,6 +1377,7 @@ tf_cc_test(
deps = [
":algebraic_simplifier",
":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c236453fc7..19bb4da9a6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -30,9 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -1238,7 +1240,7 @@ namespace {
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
const HloInstruction* hlo,
tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
@@ -1258,11 +1260,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
- return std::make_pair(false, std::vector<int64>());
+ return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
- return std::make_pair(true, output_dim_indices);
+ return output_dim_indices;
}
// Returns true if the output of "instruction" is a permutation of the
@@ -1391,6 +1393,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+ // broadcast(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateIota(
+ broadcast->shape(),
+ dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
+ }
+
// Merge two consecutive broadcasts into a single one.
if (operand->opcode() == HloOpcode::kBroadcast) {
std::vector<int64> new_dimensions;
@@ -1719,12 +1730,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
reshape, reshape->operand(0)->dimensions());
- if (opt_dims.first) {
+ if (opt_dims.has_value()) {
return ReplaceWithNewInstruction(
reshape,
HloInstruction::CreateBroadcast(
reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
- opt_dims.second));
+ *opt_dims));
+ }
+ }
+
+ // reshape(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ auto* iota = Cast<HloIotaInstruction>(operand);
+ auto opt_dims =
+ ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
+ if (opt_dims.has_value()) {
+ CHECK_EQ(opt_dims->size(), 1);
+ return ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
}
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index bb63ea26d4..1900a05750 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -23,8 +23,10 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -1826,6 +1828,105 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
op::Reshape(op::Broadcast(param)));
}
+TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(HloInstruction::CreateIota(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2));
+ Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
+ builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension(),
+ 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ const int64 iota_dim =
+ Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension();
+ EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
@@ -2653,6 +2754,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
}
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
+ HloComputation::Builder builder(TestName());
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 9ace0d76e0..5502e565b6 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -188,6 +188,7 @@ HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
+HLO_MATCHER(Iota);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
HLO_MATCHER(Le);
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index ccc06ce613..4869db79e7 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) {
}
XLA_NULLOP_PATTERN(Constant)
XLA_NULLOP_PATTERN(Parameter)
+XLA_NULLOP_PATTERN(Iota)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.