aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-02 16:27:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-02 17:49:26 -0700
commit58196d4bf923d6fa2500e84d9d22ed8227ba305c (patch)
tree8e00cc8683614dc45306152ef56cedf9c7c9f93d /tensorflow/compiler/xla
parenta5749019e065b25f49531de8b9f29627fb12fc5f (diff)
[TF:XLA] Added unittest for transpose constant folding
Transpose constant folding was missing a unittest. Change: 154903586
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/literal_util.h46
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc78
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h65
4 files changed, 173 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index ae3d43e56c..3a6d21979e 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -339,6 +340,14 @@ class LiteralUtil {
const Layout& layout,
Literal* literal);
+ // Populates literal values by calling the generator function for every cell
+ // in the literal object.
+ template <typename NativeT>
+ static Status Populate(
+ Literal* literal,
+ const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+ generator);
+
// Creates a Literal of the given dimensions with all elements set to the
// given value.
template <typename NativeT>
@@ -993,6 +1002,43 @@ template <typename NativeT>
}
template <typename NativeT>
+/* static */ Status LiteralUtil::Populate(
+ Literal* literal,
+ const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
+ generator) {
+ const Shape& shape = literal->shape();
+ int64 rank = ShapeUtil::Rank(shape);
+ TF_RET_CHECK(shape.element_type() ==
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ tensorflow::protobuf::RepeatedField<NativeT>* data =
+ GetMutableRepeatedField<NativeT>(literal);
+ if (rank > 0) {
+ std::vector<int64> base(rank, 0);
+ std::vector<int64> step(rank, 1);
+ std::vector<int64> minor_scan_indexes(rank, 0);
+ int64 minor_dimension = shape.layout().minor_to_major()[0];
+ int64 minor_dimension_size =
+ ShapeUtil::GetDimension(shape, minor_dimension);
+
+ step[minor_dimension] = minor_dimension_size;
+ auto init_function = [&](const std::vector<int64>& indexes) {
+ int64 index = LinearIndex(*literal, indexes);
+ std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
+ for (int64 i = 0; i < minor_dimension_size; ++i) {
+ minor_scan_indexes[minor_dimension] = i;
+ data->Set(index + i, generator(minor_scan_indexes));
+ }
+ return true;
+ };
+ ShapeUtil::ForEachIndex(shape, base, AsInt64Slice(shape.dimensions()), step,
+ init_function);
+ } else {
+ data->Set(0, generator({}));
+ }
+ return Status::OK();
+}
+
+template <typename NativeT>
/* static */ void LiteralUtil::PopulateWithValue(
NativeT value, tensorflow::gtl::ArraySlice<int64> dimensions,
Literal* literal) {
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index bdb69b6e55..750e1ee3f2 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1436,6 +1436,7 @@ cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index d20f423bd6..21d93a1f27 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
namespace op = xla::testing::opcode_matchers;
@@ -49,8 +50,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
@@ -70,8 +72,9 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
@@ -91,8 +94,9 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(
@@ -131,11 +135,12 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
builder.AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, test_config.concat_dimension));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
@@ -148,21 +153,60 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
- auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
- HloInstruction* lit_insn = builder.AddInstruction(
+ TF_ASSIGN_OR_ASSERT_OK(auto literal,
+ LiteralTestUtil::CreateRandomLiteral<F32>(
+ ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+ HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ shape, literal_instruction, slice_start, slice_limits));
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+}
+
+TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
+ HloComputation::Builder builder(TestName());
+ const int64 dimensions[] = {11, 8, 7, 5, 9};
+ TF_ASSIGN_OR_ASSERT_OK(auto literal,
+ LiteralTestUtil::CreateRandomLiteral<F32>(
+ ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+ auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+ HloInstruction* literal_instruction = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+ Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+ const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
- HloInstruction::CreateSlice(shape, lit_insn, slice_start, slice_limits));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
- HloConstantFolding simplifier;
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ HloConstantFolding const_folder;
+ TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+ bool matched = true;
+ LiteralUtil::EachCell<NativeT>(
+ root->literal(),
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ matched = matched && (value == LiteralUtil::Get<NativeT>(*literal_clone,
+ rindexes));
+ });
+ EXPECT_TRUE(matched);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index aeadc023cc..4f98083033 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <initializer_list>
#include <memory>
+#include <random>
#include <string>
#include "tensorflow/compiler/xla/array2d.h"
@@ -171,6 +172,36 @@ class LiteralTestUtil {
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const Literal& literal);
+ // Creates a literal with the supplied shape, and uses the provided value
+ // generator to populate the literal's values.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+
+ // Creates a literal with the supplied shape, and initializes the literal
+ // values using a normal distribution with given mean and stddev standard
+ // deviation, and using the engine as entropy generator.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type, typename E,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev);
+
+ // Creates a literal with the supplied shape, and initializes the literal
+ // values using a normal distribution with given mean and stddev standard
+ // deviation.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev);
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
};
@@ -270,6 +301,40 @@ template <typename NativeT>
ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error);
}
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ TF_RET_CHECK(shape.element_type() == type);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateFromShape(shape);
+ TF_RETURN_IF_ERROR(LiteralUtil::Populate<NativeT>(
+ literal.get(), [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ return generator(indexes);
+ }));
+ return std::move(literal);
+}
+
+template <PrimitiveType type, typename E, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
+ T stddev) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ std::normal_distribution<NativeT> generator(mean, stddev);
+ return CreateRandomLiteral<type, NativeT>(
+ shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
+ return generator(*engine);
+ });
+}
+
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+ std::minstd_rand0 engine;
+ return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_