aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc214
1 files changed, 158 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index b733f6f59e..8b81b4c97e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
@@ -74,12 +74,32 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
+// Test that A * 0 is simplified to 0
+TEST_F(AlgebraicSimplifierTest, MulZero) {
+ Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0s32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ EXPECT_EQ(computation->root_instruction(), zero);
+}
+
// Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
HloComputation::Builder builder(TestName());
// Create add computation.
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
@@ -119,7 +139,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
@@ -140,9 +160,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
@@ -165,7 +185,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction* bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
builder.AddInstruction(
@@ -200,7 +220,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateMap(
r2f32,
{param0, builder.AddInstruction(
@@ -223,7 +243,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
HloInstruction* bcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
builder.AddInstruction(
@@ -242,7 +262,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14f, 3.14f, 3.14f})));
+ LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -258,7 +278,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14, 3.14, 4})));
+ LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -277,7 +297,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
@@ -298,7 +318,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kSubtract, param0, constant));
@@ -493,7 +513,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
@@ -559,7 +579,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
@@ -580,7 +600,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
@@ -860,7 +880,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
@@ -884,7 +904,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
@@ -912,7 +932,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
@@ -934,7 +954,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
@@ -956,7 +976,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* negative_one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
@@ -1047,7 +1067,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
window, add_computation));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
@@ -1074,7 +1094,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
module().AddEntryComputation(builder.Build());
EXPECT_THAT(module().entry_computation()->root_instruction(),
@@ -1116,7 +1136,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@@ -1208,7 +1228,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
@@ -1230,6 +1250,55 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
op::Concatenate(param0, param0, param1));
}
+// Test that reduce of concat is simplified.
+TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
+ const int kParamLength = 100;
+ Shape r3f32 =
+ ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r3f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r3f32, "param1"));
+ HloInstruction* param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r3f32, "param2"));
+ Shape concat_shape =
+ ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
+ HloInstruction* Concatenate =
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ concat_shape, {param0, param1, param2}, 1));
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
+
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(HloInstruction::CreateReduce(
+ reduce_shape, Concatenate, zero, {1, 2}, add_computation));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(
+ computation->root_instruction(),
+ op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)),
+ op::Reduce(param2, zero)));
+}
+
// Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
const int kParamLength = 100;
@@ -1238,7 +1307,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
@@ -1420,7 +1489,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
@@ -1443,7 +1512,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
@@ -1726,7 +1795,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
@@ -1757,7 +1826,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
@@ -1839,6 +1908,39 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
EXPECT_THAT(computation->root_instruction(), param);
}
+TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
+ HloComputation::Builder builder(TestName());
+ const int64 dim0 = 11;
+ const int64 dim1 = 12;
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
+ HloInstruction* original_slice =
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param,
+ /*start_indices=*/{1, 2},
+ /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1}));
+
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice,
+ /*start_indices=*/{2, 3},
+ /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1}));
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param)));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Slice(param));
+ EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3);
+ EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5);
+ EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2);
+ EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4);
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;
@@ -2109,7 +2211,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloComputation::Builder builder(TestName());
HloInstruction* forty_two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
HloInstruction* broadcast = builder.AddInstruction(
@@ -2156,7 +2258,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
@@ -2187,7 +2289,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, pad, reduce_init_value, window,
@@ -2238,7 +2340,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
@@ -2273,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, convert, reduce_init_value, window,
@@ -2344,9 +2446,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
HloComputation::Builder call_builder(TestName() + ".Call");
HloInstruction* zero = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
HloInstruction* one = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
call_builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
@@ -2362,9 +2464,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get()});
+ std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2387,8 +2489,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
shape,
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "slice_from")),
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")),
/*slice_sizes=*/{10, 100, 1000}));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2421,8 +2523,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
builder.AddInstruction(
HloInstruction::CreateParameter(2, slice_shape, "to_update")),
slice,
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 3, ShapeUtil::MakeShape(U32, {3}), "update_indices"))));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -2437,7 +2539,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
HloComputation::Builder builder(TestName());
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* input_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
HloInstruction* inner_bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
@@ -2546,7 +2648,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
pad_shape, input,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
HloComputation* add_computation = nullptr;
@@ -2565,7 +2667,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
Window window = window_util::MakeWindow(
decorate_spatials(param.reduce_window_spatials, 1, 1));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
ShapeInference::InferReduceWindowShape(
pad->shape(), zero->shape(), window,
@@ -2704,7 +2806,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
@@ -2783,7 +2885,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
DotDimensionNumbers dot_dnums;
@@ -2830,7 +2932,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
HloInstruction* const update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
HloInstruction* const start_indices = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int>({0})));
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
dslice_shape, operand, update, start_indices));
const HloComputation* const computation =
@@ -2879,7 +2981,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -2887,7 +2989,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int32 start_col = (spec.lcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
@@ -2898,7 +3000,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -2946,7 +3048,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -2957,7 +3059,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -2965,7 +3067,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int32 start_col = (spec.rcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});