aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc133
1 files changed, 0 insertions, 133 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 0cce076da5..f4b42055b7 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -466,75 +466,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
EXPECT_THAT(computation->root_instruction(), input);
}
-TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) {
- HloComputation::Builder builder(TestName());
- HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
- builder.AddInstruction(
- HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
-
- auto module = MakeUnique<HloModule>(TestName());
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(), op::Convert(input));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(), op::Constant());
- EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
- computation->root_instruction()->literal()),
- 42);
-}
-
-TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) {
- HloComputation::Builder builder(TestName());
- HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
- builder.AddInstruction(
- HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
-
- auto module = MakeUnique<HloModule>(TestName());
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(), op::Convert(input));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(), op::Constant());
- EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
- computation->root_instruction()->literal()),
- 42.0f);
-}
-
-TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) {
- HloComputation::Builder builder(TestName());
- HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
- builder.AddInstruction(
- HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
-
- auto module = MakeUnique<HloModule>(TestName());
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(), op::Convert(input));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(), op::Constant());
- EXPECT_EQ(
- LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
- 42);
- EXPECT_EQ(
- LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
- 19);
-}
-
// Test that copies are removed.
TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -1666,69 +1597,5 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
-TEST_F(AlgebraicSimplifierTest, Concatenate) {
- const struct TestConfig {
- int concat_dimension;
- tensorflow::gtl::ArraySlice<int64> dimensions;
- tensorflow::gtl::ArraySlice<int64> concat_sizes;
- } test_configs[] = {
- {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
- {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
- };
-
- for (auto& test_config : test_configs) {
- HloComputation::Builder builder(TestName());
- std::vector<int64> dimensions(test_config.dimensions.begin(),
- test_config.dimensions.end());
- int64 concat_size = 0;
- std::vector<HloInstruction*> operands;
- for (auto csize : test_config.concat_sizes) {
- dimensions[test_config.concat_dimension] = csize;
- concat_size += csize;
- auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
- HloInstruction* insn = builder.AddInstruction(
- HloInstruction::CreateConstant(std::move(literal)));
- operands.push_back(insn);
- }
- dimensions[test_config.concat_dimension] = concat_size;
- Shape shape = ShapeUtil::MakeShape(F32, dimensions);
- builder.AddInstruction(HloInstruction::CreateConcatenate(
- shape, operands, test_config.concat_dimension));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-
- HloInstruction* root = computation->root_instruction();
- EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
- EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
- }
-}
-
-TEST_F(AlgebraicSimplifierTest, Slice) {
- HloComputation::Builder builder(TestName());
- const int64 dimensions[] = {11, 8, 7, 5, 9};
- const int64 slice_start[] = {4, 2, 3, 1, 5};
- const int64 slice_limits[] = {10, 8, 6, 5, 9};
- auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
- HloInstruction* lit_insn = builder.AddInstruction(
- HloInstruction::CreateConstant(std::move(literal)));
- Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
- builder.AddInstruction(
- HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-
- HloInstruction* root = computation->root_instruction();
- EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
- EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
-}
-
} // namespace
} // namespace xla