aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-02 16:27:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-02 17:49:26 -0700
commit58196d4bf923d6fa2500e84d9d22ed8227ba305c (patch)
tree8e00cc8683614dc45306152ef56cedf9c7c9f93d /tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
parenta5749019e065b25f49531de8b9f29627fb12fc5f (diff)
[TF:XLA] Added unittest for transpose constant folding
Transpose constant folding was missing a unittest. Change: 154903586
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_constant_folding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc78
1 files changed, 61 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index d20f423bd6..21d93a1f27 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
namespace op = xla::testing::opcode_matchers;
@@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
@@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
@@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(
@@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
builder.AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, test_config.concat_dimension));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
@@ -148,21 +153,60 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
- auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
- HloInstruction* lit_insn = builder.AddInstruction(
+ TF_ASSIGN_OR_ASSERT_OK(auto literal,
+ LiteralTestUtil::CreateRandomLiteral<F32>(
+ ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+ HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ shape, literal_instruction, slice_start, slice_limits));
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+}
+
+TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
+ HloComputation::Builder builder(TestName());
+ const int64 dimensions[] = {11, 8, 7, 5, 9};
+ TF_ASSIGN_OR_ASSERT_OK(auto literal,
+ LiteralTestUtil::CreateRandomLiteral<F32>(
+ ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+ auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+ HloInstruction* literal_instruction = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+ Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+ const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
- HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+ bool matched = true;
+ LiteralUtil::EachCell<NativeT>(
+ root->literal(),
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone,
+ rindexes));
+ });
+ EXPECT_TRUE(matched);
}
} // namespace