diff options
author | 2017-05-02 16:27:12 -0800 | |
---|---|---|
committer | 2017-05-02 17:49:26 -0700 | |
commit | 58196d4bf923d6fa2500e84d9d22ed8227ba305c (patch) | |
tree | 8e00cc8683614dc45306152ef56cedf9c7c9f93d /tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | |
parent | a5749019e065b25f49531de8b9f29627fb12fc5f (diff) |
[TF:XLA] Added unittest for transpose constant folding
Transpose constant folding was missing a unittest.
Change: 154903586
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_constant_folding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 78 |
1 files changed, 61 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index d20f423bd6..21d93a1f27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" namespace op = xla::testing::opcode_matchers; @@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement<int64>( @@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement<float>( @@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ( @@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -148,21 +153,60 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 dimensions[] = {11, 8, 7, 5, 9}; const int64 slice_start[] = {4, 2, 3, 1, 5}; const int64 slice_limits[] = {10, 8, 6, 5, 9}; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); - HloInstruction* lit_insn = builder.AddInstruction( + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); + builder.AddInstruction(HloInstruction::CreateSlice( + shape, literal_instruction, slice_start, slice_limits)); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); +} + +TEST_F(HloConstantFoldingTest, TransposeConstantFold) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( - HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + + using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type; + bool matched = true; + LiteralUtil::EachCell<NativeT>( + root->literal(), + [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) { + std::vector<int64> rindexes = Permute(permutation, indices); + matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone, + rindexes)); + }); + EXPECT_TRUE(matched); } } // namespace |