diff options
author | 2017-05-02 16:27:12 -0800 | |
---|---|---|
committer | 2017-05-02 17:49:26 -0700 | |
commit | 58196d4bf923d6fa2500e84d9d22ed8227ba305c (patch) | |
tree | 8e00cc8683614dc45306152ef56cedf9c7c9f93d /tensorflow/compiler/xla | |
parent | a5749019e065b25f49531de8b9f29627fb12fc5f (diff) |
[TF:XLA] Added unittest for transpose constant folding
Transpose constant folding was missing a unittest.
Change: 154903586
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 46 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 78 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.h | 65 |
4 files changed, 173 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index ae3d43e56c..3a6d21979e 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -339,6 +340,14 @@ class LiteralUtil { const Layout& layout, Literal* literal); + // Populates literal values by calling the generator function for every cell + // in the literal object. + template <typename NativeT> + static Status Populate( + Literal* literal, + const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>& + generator); + // Creates a Literal of the given dimensions with all elements set to the // given value. template <typename NativeT> @@ -993,6 +1002,43 @@ template <typename NativeT> } template <typename NativeT> +/* static */ Status LiteralUtil::Populate( + Literal* literal, + const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>& + generator) { + const Shape& shape = literal->shape(); + int64 rank = ShapeUtil::Rank(shape); + TF_RET_CHECK(shape.element_type() == + primitive_util::NativeToPrimitiveType<NativeT>()); + tensorflow::protobuf::RepeatedField<NativeT>* data = + GetMutableRepeatedField<NativeT>(literal); + if (rank > 0) { + std::vector<int64> base(rank, 0); + std::vector<int64> step(rank, 1); + std::vector<int64> minor_scan_indexes(rank, 0); + int64 minor_dimension = shape.layout().minor_to_major()[0]; + int64 minor_dimension_size = + ShapeUtil::GetDimension(shape, minor_dimension); + + step[minor_dimension] = minor_dimension_size; + auto init_function = [&](const std::vector<int64>& indexes) { + int64 index = LinearIndex(*literal, indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[minor_dimension] = i; + data->Set(index + i, generator(minor_scan_indexes)); + } + return true; + }; + ShapeUtil::ForEachIndex(shape, base, AsInt64Slice(shape.dimensions()), step, + init_function); + } else { + data->Set(0, generator({})); + } + return Status::OK(); +} + +template <typename NativeT> /* static */ void LiteralUtil::PopulateWithValue( NativeT value, tensorflow::gtl::ArraySlice<int64> dimensions, Literal* literal) { diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index bdb69b6e55..750e1ee3f2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1436,6 +1436,7 @@ cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index d20f423bd6..21d93a1f27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" namespace op = xla::testing::opcode_matchers; @@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement<int64>( @@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement<float>( @@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ( @@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -148,21 +153,60 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 dimensions[] = {11, 8, 7, 5, 9}; const int64 slice_start[] = {4, 2, 3, 1, 5}; const int64 slice_limits[] = {10, 8, 6, 5, 9}; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); - HloInstruction* lit_insn = builder.AddInstruction( + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); + builder.AddInstruction(HloInstruction::CreateSlice( + shape, literal_instruction, slice_start, slice_limits)); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); +} + +TEST_F(HloConstantFoldingTest, TransposeConstantFold) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral<F32>( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( - HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + auto module = MakeUnique<HloModule>(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); - HloConstantFolding simplifier; - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + + using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type; + bool matched = true; + LiteralUtil::EachCell<NativeT>( + root->literal(), + [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) { + std::vector<int64> rindexes = Permute(permutation, indices); + matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone, + rindexes)); + }); + EXPECT_TRUE(matched); } } // namespace diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index aeadc023cc..4f98083033 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -18,6 +18,7 @@ limitations under the License. #include <initializer_list> #include <memory> +#include <random> #include <string> #include "tensorflow/compiler/xla/array2d.h" @@ -171,6 +172,36 @@ class LiteralTestUtil { tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal); + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> + static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( + const Shape& shape, + const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> + static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative<type>::type> + static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); + private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; @@ -270,6 +301,40 @@ template <typename NativeT> ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); } +template <PrimitiveType type, typename T> +/* static */ StatusOr<std::unique_ptr<Literal>> +LiteralTestUtil::CreateRandomLiteral( + const Shape& shape, + const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; + TF_RET_CHECK(shape.element_type() == type); + std::unique_ptr<Literal> literal = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate<NativeT>( + literal.get(), [&](tensorflow::gtl::ArraySlice<int64> indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template <PrimitiveType type, typename E, typename T> +/* static */ StatusOr<std::unique_ptr<Literal>> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type; + std::normal_distribution<NativeT> generator(mean, stddev); + return CreateRandomLiteral<type, NativeT>( + shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) { + return generator(*engine); + }); +} + +template <PrimitiveType type, typename T> +/* static */ StatusOr<std::unique_ptr<Literal>> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral<type>(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ |