aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc1368
1 files changed, 1368 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
new file mode 100644
index 0000000000..49ea91f83b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -0,0 +1,1368 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace {
+
+AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
+ return [](const Shape&, const Shape&) { return true; };
+}
+AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
+ return [](const Shape&, const Shape&) { return false; };
+}
+
+using AlgebraicSimplifierTest = HloTestBase;
+
+// Test that A + 0 is simplified to A
+TEST_F(AlgebraicSimplifierTest, AddZero) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A - 0 is simplified to A
+TEST_F(AlgebraicSimplifierTest, SubZero) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A/1 is simplified to A for a scalar.
+TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, div);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A/1 is simplified to A for an array.
+TEST_F(AlgebraicSimplifierTest, DivOneArray) {
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, div);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that get_element(make_tuple({A,B}),1) is simplified to B
+TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
+ HloInstruction* get = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, add);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, add);
+ EXPECT_EQ(root->operand(0), param1);
+ EXPECT_EQ(root->operand(1), param2);
+}
+
+// Test that exp(A)/exp(B) is simplified to exp(A-B)
+TEST_F(AlgebraicSimplifierTest, ExpDiv) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kDivide);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kExp);
+ EXPECT_EQ(root->operand_count(), 1);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kSubtract);
+ EXPECT_EQ(root->operand(0)->operand(0), param0);
+ EXPECT_EQ(root->operand(0)->operand(1), param1);
+}
+
+// Test that ln(exp(A)) is simplified to A
+TEST_F(AlgebraicSimplifierTest, LnExp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kLog);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+ EXPECT_EQ(root, param0);
+}
+
+// Test that ln(exp(A)/exp(B)) is simplified to A-B
+TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kLog);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
+ EXPECT_EQ(root->operand(0), param0);
+ EXPECT_EQ(root->operand(1), param1);
+}
+
+// Test that pow(A, 0) where A is a scalar is simplified to the scalar
+// constant 1.
+TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->literal()), 1);
+}
+
+// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
+TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
+ << ShapeUtil::HumanString(root->shape());
+ EXPECT_EQ(root->dimensions().size(), 0);
+ EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
+ 1);
+}
+
+// Test that pow(A, 1) is simplified to A.
+TEST_F(AlgebraicSimplifierTest, Pow1) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+ EXPECT_EQ(root, param0);
+}
+
+// Test that pow(A, 2) is simplified to A*A.
+TEST_F(AlgebraicSimplifierTest, Pow2) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* two = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
+ EXPECT_EQ(root->operand(0), param0);
+ EXPECT_EQ(root->operand(1), param0);
+}
+
+// Test that pow(A, -1) is simplified to 1/A.
+TEST_F(AlgebraicSimplifierTest, PowNegative1) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* negative_one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
+ builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
+ param0, negative_one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kDivide);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConstant);
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
+ 1);
+ EXPECT_EQ(root->operand(1), param0);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto op = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
+
+ auto computation = builder.Build();
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(std::move(computation));
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+}
+
+// Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
+TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
+ computation->root_instruction()->literal()),
+ 42);
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
+ computation->root_instruction()->literal()),
+ 42.0f);
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(
+ LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
+ 42);
+ EXPECT_EQ(
+ LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
+ 19);
+}
+
+// Test that copies are removed.
+TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(param0, computation->root_instruction());
+}
+
+// Test that a simplification which changes layouts is not performed if layout
+// sensitive is true.
+TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ // Set to different layouts.
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Copy has not been removed.
+ EXPECT_EQ(copy, computation->root_instruction());
+}
+
+// Test that a simplification which preserves layouts is performed if layout
+// sensitive is true.
+TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ // Set to same layouts.
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Copy has been removed.
+ EXPECT_EQ(param0, computation->root_instruction());
+}
+
+// Test that a reshape which could be replaced with a bitcast is not if
+// add_bitcasts is false.
+TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(reshape, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Reshape is not replaced with a bitcast.
+ EXPECT_EQ(reshape, computation->root_instruction());
+}
+
+// Test transforming reshapes to bitcasts under various conditions.
+TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ // Reshape which can be transformed into a bitcast.
+ HloInstruction* transformable_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+ *transformable_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ // Reshape does not just add degenerate dimensions.
+ HloInstruction* dimensions_wrong_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
+ *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ // Reshape has wrong layout.
+ HloInstruction* layout_wrong_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+ *layout_wrong_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
+
+ // Collect all the reshapes into a tuple so they are not dead.
+ builder.AddInstruction(HloInstruction::CreateTuple(
+ {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(transformable_reshape, computation->root_instruction()->operand(0));
+ EXPECT_EQ(dimensions_wrong_reshape,
+ computation->root_instruction()->operand(1));
+ EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that only the first reshape is replaced.
+ EXPECT_NE(transformable_reshape, computation->root_instruction()->operand(0));
+ EXPECT_EQ(HloOpcode::kBitcast,
+ computation->root_instruction()->operand(0)->opcode());
+ EXPECT_EQ(dimensions_wrong_reshape,
+ computation->root_instruction()->operand(1));
+ EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
+ *param->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({1, 2, 0, 3});
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3}));
+ *transpose->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that the reshape is replaced.
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
+ *param->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({1, 2, 3, 0});
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1}));
+ *transpose->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({3, 1, 2, 0});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that the reshape is replaced.
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+
+ HloInstruction* reshape1 =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {2, 1, 2}), param0));
+
+ HloInstruction* reshape2 =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(reshape2, computation->root_instruction());
+ EXPECT_EQ(reshape1, computation->root_instruction()->operand(0));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
+
+ HloInstruction* transpose1 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));
+
+ HloInstruction* transpose2 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_EQ(transpose1, computation->root_instruction()->operand(0));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kTranspose, computation->root_instruction()->opcode());
+ EXPECT_EQ(std::vector<int64>({2, 1, 0}),
+ computation->root_instruction()->dimensions());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+// Test merging reshape and broadcast.
+TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {5}), "param0"));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+// Test merging broadcast and reshape.
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
+ auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_MATCH(computation->root_instruction()->dimensions(),
+ testing::VectorMatcher<int64>({3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ const std::vector<int64> broadcast_dims =
+ computation->root_instruction()->dimensions();
+ EXPECT_EQ(1, broadcast_dims.size());
+ EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 ||
+ broadcast_dims[3] == 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ PaddingConfig no_padding;
+ for (auto i = 0; i < 2; ++i) {
+ auto dimension = no_padding.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(0);
+ }
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3}), "param"));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
+ HloComputation::Builder builder(TestName());
+ const int64 dim0 = 2;
+ const int64 dim1 = 3;
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
+ /*limit_indices=*/{dim0, dim1}));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
+ struct ConvTestOptions {
+ int in_batch = 10;
+ int in_height = 2;
+ int in_width = 2;
+ int in_channels = 3;
+ int f_width = 1;
+ int f_height = 1;
+ int f_output_channels = 10;
+ int row_stride = 1;
+ int row_padding = 0;
+ int col_stride = 1;
+ int col_padding = 0;
+ bool input_minor_to_major_layout = false;
+ bool filter_minor_to_major_layout = false;
+ bool output_minor_to_major_layout = false;
+
+ const char* dim_order = "NHWC"; // can use chars NHWC in any order.
+ const char* kernel_dim_order = "HWIO"; // can use chars HWIO in any order.
+
+ ConvTestOptions& Reset() {
+ *this = ConvTestOptions();
+ return *this;
+ }
+ };
+
+ ConvTestOptions options;
+
+ // Builds a convolution from <options> and runs algebraic simplification on
+ // the computation. Returns a string description of the result of
+ // simplification.
+ auto build_and_simplify = [&options, this]() -> string {
+ HloComputation::Builder b(TestName());
+
+ Window window;
+ auto* f_dim_1 = window.add_dimensions();
+ f_dim_1->set_size(options.f_height);
+ f_dim_1->set_stride(options.row_stride);
+ f_dim_1->set_padding_low(options.row_padding);
+ f_dim_1->set_padding_high(options.row_padding);
+ f_dim_1->set_window_dilation(1);
+ f_dim_1->set_base_dilation(1);
+ auto* f_dim_2 = window.add_dimensions();
+ f_dim_2->set_size(options.f_width);
+ f_dim_2->set_stride(options.col_stride);
+ f_dim_2->set_padding_low(options.col_padding);
+ f_dim_2->set_padding_high(options.col_padding);
+ f_dim_2->set_window_dilation(1);
+ f_dim_2->set_base_dilation(1);
+
+ ConvolutionDimensionNumbers dnums;
+ std::vector<int64> in_dims;
+ int in_channel_idx = -1;
+ dnums.add_spatial_dimensions(-1); // filled in later
+ dnums.add_spatial_dimensions(-1); // filled in later
+ for (int i = 0; i < strlen(options.dim_order); ++i) {
+ char ch = options.dim_order[i];
+ if (ch == 'N') {
+ dnums.set_batch_dimension(i);
+ in_dims.push_back(options.in_batch);
+ } else if (ch == 'H') {
+ dnums.set_spatial_dimensions(0, i);
+ in_dims.push_back(options.in_height);
+ } else if (ch == 'W') {
+ dnums.set_spatial_dimensions(1, i);
+ in_dims.push_back(options.in_width);
+ } else if (ch == 'C') {
+ dnums.set_feature_dimension(i);
+ in_dims.push_back(options.in_channels);
+ in_channel_idx = i;
+ }
+ }
+
+ std::vector<int64> f_dims;
+ dnums.add_kernel_spatial_dimensions(-1); // filled in later
+ dnums.add_kernel_spatial_dimensions(-1); // filled in later
+ for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
+ char ch = options.kernel_dim_order[i];
+ if (ch == 'H') {
+ dnums.set_kernel_spatial_dimensions(0, i);
+ f_dims.push_back(options.f_height);
+ } else if (ch == 'W') {
+ dnums.set_kernel_spatial_dimensions(1, i);
+ f_dims.push_back(options.f_width);
+ } else if (ch == 'I') {
+ dnums.set_kernel_input_feature_dimension(i);
+ f_dims.push_back(options.in_channels);
+ } else if (ch == 'O') {
+ dnums.set_kernel_output_feature_dimension(i);
+ f_dims.push_back(options.f_output_channels);
+ }
+ }
+
+ auto out_dims = in_dims;
+ out_dims[in_channel_idx] = options.f_output_channels;
+
+ auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
+ bool minor_to_major_layout) {
+ if (minor_to_major_layout) {
+ return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
+ } else {
+ return ShapeUtil::MakeShape(F32, dims);
+ }
+ };
+ auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
+ auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
+ auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
+
+ HloInstruction* input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
+ HloInstruction* filter =
+ b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
+
+ b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
+ window, dnums));
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(b.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ if (!simplifier.Run(&module).ValueOrDie()) {
+ return "NO_CHANGE";
+ }
+ auto* root = computation->root_instruction();
+ if (root->opcode() == HloOpcode::kBitcast &&
+ root->operand(0)->opcode() == HloOpcode::kDot) {
+ auto lhs_shape = root->operand(0)->operand(0)->shape();
+ auto rhs_shape = root->operand(0)->operand(1)->shape();
+ return tensorflow::strings::StrCat(
+ tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
+ tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
+ }
+ return "UNEXPECTED CHANGE";
+ };
+
+ // Default options are the simplest case and succeed.
+ options.Reset();
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // Swapping dim spatial and batch order works.
+ options.Reset().dim_order = "NWHC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().dim_order = "WHNC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ // Channel dimension earlier fails.
+ options.Reset().dim_order = "HWCN";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().dim_order = "CHWN";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // Filtering dims spatial dims can be anywhere, since they are 1x1.
+ options.Reset().kernel_dim_order = "WHIO";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().kernel_dim_order = "IWOH";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().kernel_dim_order = "IWHO";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ // But moving output channel before input channel fails.
+ options.Reset().kernel_dim_order = "HWOI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "WHOI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "OWIH";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "OWHI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // Combine different dim and kernel dim orders.
+ options.Reset().kernel_dim_order = "IWHO";
+ options.dim_order = "WHNC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // Test invalid cases from wrong filter size, strides, or padding.
+ options.Reset().f_width = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().f_height = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().row_stride = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().col_stride = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().col_padding = 1;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().row_padding = 1;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // The default dim_order is "NHWC". Col-major layout makes C the most major.
+ options.Reset().input_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // The input and output have different layouts.
+ options.Reset().input_minor_to_major_layout = true;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // C is most minor, and I is more major than O.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "CHWN";
+ options.kernel_dim_order = "OIHW";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // C is not the most minor dimension.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "HWNC";
+ options.kernel_dim_order = "OIHW";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // I is more minor than O.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "CHWN";
+ options.kernel_dim_order = "IOHW";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+}
+
+// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
+TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMinimum, param0, min_value));
+ HloInstruction* max = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, max);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ ASSERT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
+
+// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
+// values.
+TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
+
+// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
+// broadcasted scalar values.
+TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
+
+// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
+// clamp(non-constant1, A, non-constant2)
+TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+}
+
+// Test that min(f(max(A, constant1)), constant2) is not transformed to
+// clamp(constant1, A, constant2)
+TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* fmax = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
+ HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMinimum, fmax, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+}
+
+// Test that slice(broadcast(/*scalar value*/)) simplifies to a single
+// broadcast.
+TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
+
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ broadcast_shape, scalar_param,
+ AsInt64Slice(broadcast_shape.dimensions())));
+
+ Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
+ HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
+ slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, slice);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_EQ(scalar_param, root->operand(0));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
+}
+
+// Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
+// single broadcast.
+TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* forty_two = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ broadcast_shape, forty_two,
+ AsInt64Slice(broadcast_shape.dimensions())));
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
+
+ Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
+ HloInstruction* reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(reshape_shape, transpose));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, reshape);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_EQ(forty_two, root->operand(0));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
+}
+
+} // namespace
+} // namespace xla