diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 133 |
1 files changed, 0 insertions, 133 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0cce076da5..f4b42055b7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -466,75 +466,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { EXPECT_THAT(computation->root_instruction(), input); } -TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) { - HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); - builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - - auto module = MakeUnique<HloModule>(TestName()); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement<int64>( - computation->root_instruction()->literal()), - 42); -} - -TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) { - HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42))); - builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - - auto module = MakeUnique<HloModule>(TestName()); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement<float>( - computation->root_instruction()->literal()), - 42.0f); -} - -TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) { - HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1<float>({42.0f, 19.0f}))); - builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); - - auto module = MakeUnique<HloModule>(TestName()); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ( - LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}), - 19); -} - // Test that copies are removed. TEST_F(AlgebraicSimplifierTest, RemoveCopy) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1666,69 +1597,5 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); } -TEST_F(AlgebraicSimplifierTest, Concatenate) { - const struct TestConfig { - int concat_dimension; - tensorflow::gtl::ArraySlice<int64> dimensions; - tensorflow::gtl::ArraySlice<int64> concat_sizes; - } test_configs[] = { - {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, - {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, - }; - - for (auto& test_config : test_configs) { - HloComputation::Builder builder(TestName()); - std::vector<int64> dimensions(test_config.dimensions.begin(), - test_config.dimensions.end()); - int64 concat_size = 0; - std::vector<HloInstruction*> operands; - for (auto csize : test_config.concat_sizes) { - dimensions[test_config.concat_dimension] = csize; - concat_size += csize; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); - HloInstruction* insn = builder.AddInstruction( - HloInstruction::CreateConstant(std::move(literal))); - operands.push_back(insn); - } - dimensions[test_config.concat_dimension] = concat_size; - Shape shape = ShapeUtil::MakeShape(F32, dimensions); - builder.AddInstruction(HloInstruction::CreateConcatenate( - shape, operands, test_config.concat_dimension)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kConstant); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); - } -} - -TEST_F(AlgebraicSimplifierTest, Slice) { - HloComputation::Builder builder(TestName()); - const int64 dimensions[] = {11, 8, 7, 5, 9}; - const int64 slice_start[] = {4, 2, 3, 1, 5}; - const int64 slice_limits[] = {10, 8, 6, 5, 9}; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); - HloInstruction* lit_insn = builder.AddInstruction( - HloInstruction::CreateConstant(std::move(literal))); - Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); - builder.AddInstruction( - HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kConstant); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); -} - } // namespace } // namespace xla |