author    Mark Heffernan <meheff@google.com>  2017-11-08 15:35:27 -0800
committer Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:37 -0800
commit    64d2636e2946772d4b1531ec91b389110a2787b7 (patch)
tree      71fdb3d248e9b1f24ab649e1d21eac490e6782bd
parent    2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (diff)
Move MakeFakeLiteral from client/lib/testing.h to tests/test_utils.h. Also remove superfluous literal creation methods in that file, and replace them with the existing ones in the Literal class.
Also, optionally print layout in Literal::ToString.
PiperOrigin-RevId: 175076277
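For orientation, a minimal sketch of the API after this change (assuming a target that depends on //tensorflow/compiler/xla/tests:test_utils; the shape and variable names are illustrative):

    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/tests/test_utils.h"

    // MakeFakeLiteral now lives in tests/test_utils.h rather than
    // client/lib/testing.h (it remains in namespace xla).
    xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
    auto literal = xla::MakeFakeLiteral(shape).ConsumeValueOrDie();

    // ToString can now optionally include the layout in the shape prefix.
    string with_layout = literal->ToString(/*print_layout=*/true);
    string plain = literal->ToString();  // default behavior is unchanged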
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD                   |   1
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.cc              |  57
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.h               |   4
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc                    |  22
-rw-r--r--  tensorflow/compiler/xla/literal_util.h                     |   2
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                      |   2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc            |  24
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc  |  32
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                        |   3
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h   |   6
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc               |   4
-rw-r--r--  tensorflow/compiler/xla/tests/compilation_cache_test.cc    |   8
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc     |   4
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc        |  25
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_execute_test.cc |  10
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc                  |   8
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc                | 120
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h                 |  64
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD                        |   1
-rw-r--r--  tensorflow/compiler/xla/tools/replay_computation.cc        |   1
20 files changed, 209 insertions, 189 deletions
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index ee34682087..fca2bf2688 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index e6645e4941..d936bd870b 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
- for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteral(element_shape));
- elements.push_back(std::move(element));
- }
- return Literal::MakeTupleOwned(std::move(elements));
- }
- std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
- std::minstd_rand0 engine;
- switch (shape.element_type()) {
- case F32: {
- std::uniform_real_distribution<float> generator(0.0f, 1.0f);
- TF_CHECK_OK(literal->Populate<float>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case S32: {
- std::uniform_int_distribution<int32> generator(
- std::numeric_limits<int32>::lowest(),
- std::numeric_limits<int32>::max());
- TF_CHECK_OK(literal->Populate<int32>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case S64: {
- std::uniform_int_distribution<int64> generator(
- std::numeric_limits<int64>::lowest(),
- std::numeric_limits<int64>::max());
- TF_CHECK_OK(literal->Populate<int64>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- case PRED: {
- std::uniform_int_distribution<int> generator(0, 1);
- TF_CHECK_OK(literal->Populate<bool>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(engine);
- }));
- break;
- }
- default:
- return Unimplemented("Unsupported type for fake literal generation: %s",
- ShapeUtil::HumanString(shape).c_str());
- }
- return std::move(literal);
-}
-
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
index b5c4393dcc..7e640d1307 100644
--- a/tensorflow/compiler/xla/client/lib/testing.h
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -26,10 +26,6 @@ limitations under the License.
namespace xla {
-// Generates fake data in a literal of the given shape, or returns an error
-// status if the element type is currently unhandled for fake data generation.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-
// Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring
// data from the host to the device.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index fda791401d..0cb2223ae5 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
}
-string Literal::ToString() const {
+string Literal::ToString(bool print_layout) const {
std::vector<string> pieces;
+ auto shape_to_string = [print_layout](const Shape& shape) {
+ if (print_layout) {
+ return ShapeUtil::HumanStringWithLayout(shape);
+ } else {
+ return ShapeUtil::HumanString(shape);
+ }
+ };
+
auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type();
@@ -585,7 +593,7 @@ string Literal::ToString() const {
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +609,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { ");
@@ -613,7 +621,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +636,7 @@ string Literal::ToString() const {
}
pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -649,7 +657,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -676,7 +684,7 @@ string Literal::ToString() const {
}
pieces.push_back("}");
} else {
- pieces.push_back(ShapeUtil::HumanString(shape()));
+ pieces.push_back(shape_to_string(shape()));
pieces.push_back(" {...}");
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a1e288829f..667f926c46 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -450,7 +450,7 @@ class Literal {
tensorflow::Status ValidateLiteral() const;
// Returns a string representation of the literal value.
- string ToString() const;
+ string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of
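As a hedged illustration of the new flag (the "{1, 0}" suffix is the minor-to-major layout as printed by ShapeUtil::HumanStringWithLayout; the literal value here is illustrative):

    auto m = xla::Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
    m->ToString();                       // shape prefix "f32[2,2]", as before
    m->ToString(/*print_layout=*/true);  // shape prefix "f32[2,2]{1,0}"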
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index c6f6c6c38b..7cf24641b5 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1780,7 +1780,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
@@ -1851,7 +1850,6 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 7c4626e78a..3601a790c4 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// Test that two identical constants with different layouts are commoned if
// the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{0, 1})));
- auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{1, 0})));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// Test that two identical constants with different layouts are *not* commoned
// if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{0, 1})));
- auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- /*minor_to_major=*/{1, 0})));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
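The pattern applied throughout these test updates, shown once in isolation (both calls build the same row-major 2x2 literal; only the second form compiles after this change):

    // Before: helper in namespace test_utils, taking a bare minor_to_major list.
    auto old_style = test_utils::CreateR2LiteralWithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
    // After: the existing Literal factory plus an explicit Layout.
    auto new_style = Literal::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}));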
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index c39ff52230..d51c0d1dfb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
- auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
- auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
- {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
+ auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
+ auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+ {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
auto constant1 = builder.AddInstruction(
@@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// Verify the layouts of a tuple are assigned properly (the element layouts
// match their source).
auto builder = HloComputation::Builder(TestName());
- auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {0, 1})));
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {1, 0})));
+ auto constant0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
- auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {0, 1})));
- auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
- {1, 0})));
+ auto constant0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
auto tuple1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 4e1be24b61..2333a30ad5 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -61,13 +61,14 @@ generate_backend_test_macros()
cc_library(
name = "test_utils",
- testonly = True,
+ srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 7cfc276ec1..2c37466ff2 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -469,8 +469,7 @@ template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
std::vector<NativeT> result(width);
- test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
- seed);
+ PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int i = 0; i < width; ++i) {
result[i] = generator.get();
}
@@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
- test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
- seed);
+ PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
(*result)(y, x) = generator.get();
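With the test_utils namespace removed, PseudorandomGenerator is referenced directly from namespace xla; a minimal usage sketch (the seed value is illustrative):

    xla::PseudorandomGenerator<float> generator(/*min_value=*/0.0f,
                                                /*max_value=*/1.0f, /*seed=*/42);
    float value = generator.get();  // deterministic for a fixed seed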
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 0853feeebd..183bcf1dd3 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
.ConsumeValueOrDie();
std::unique_ptr<Literal> expected_literal =
- test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
- transfer_layout);
+ Literal::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
auto computed = client_->Transfer(*data, &expected_literal->shape());
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 707e439245..0f780fa87e 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache
// miss).
- auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+ auto rowmaj_array = Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
- auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+ auto colmaj_array = Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index d423c78476..5226a78386 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ASSERT_TRUE(computed.ok()) << computed.status();
std::unique_ptr<Literal> expected_literal =
- test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
- layout);
+ Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
+ LayoutUtil::MakeLayout(layout));
LiteralTestUtil::AssertEqualShapesAndLayouts(
expected_literal->shape(), computed.ValueOrDie()->shape());
LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index c4e422b506..b72dd2707c 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0}, {3.0, -4.0}},
- MinorToMajorForIsRowMajor(lhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {7.0, -4.0}},
- MinorToMajorForIsRowMajor(rhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
- MinorToMajorForIsRowMajor(lhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
{{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
- MinorToMajorForIsRowMajor(rhs_row_major)))
+ LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
@@ -420,13 +420,14 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
- {{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
+ ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
- {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
+ ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
+ LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 329b53012f..a196e250d1 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(
- *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
- /*minor_to_major=*/{0, 1}));
+ auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(
- *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
- /*minor_to_major=*/{1, 0}));
+ auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 2ef392508d..2b0f7e6e80 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
ComputationBuilder builder(client_, TestName());
- std::unique_ptr<Literal> param0_literal =
- test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+ std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+ {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
- test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+ std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+ {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
new file mode 100644
index 0000000000..cdd3d66bbb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+#include "tensorflow/compiler/xla/primitive_util.h"
+
+namespace xla {
+
+namespace {
+
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<FloatT>());
+ std::minstd_rand0 engine;
+ std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<FloatT>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+}
+
+template <typename IntT>
+void PopulateWithRandomIntegralData(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<IntT>());
+ std::minstd_rand0 engine;
+ std::uniform_int_distribution<IntT> generator(
+ std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+ TF_CHECK_OK(literal->Populate<IntT>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ std::vector<std::unique_ptr<Literal>> elements;
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+ MakeFakeLiteral(element_shape));
+ elements.push_back(std::move(element));
+ }
+ return Literal::MakeTupleOwned(std::move(elements));
+ }
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ switch (shape.element_type()) {
+ case F32:
+ PopulateWithRandomFloatingPointData<float>(literal.get());
+ break;
+ case F64:
+ PopulateWithRandomFloatingPointData<double>(literal.get());
+ break;
+ case S8:
+ PopulateWithRandomIntegralData<int8>(literal.get());
+ break;
+ case U8:
+ PopulateWithRandomIntegralData<uint8>(literal.get());
+ break;
+ case S16:
+ PopulateWithRandomIntegralData<int16>(literal.get());
+ break;
+ case U16:
+ PopulateWithRandomIntegralData<uint16>(literal.get());
+ break;
+ case S32:
+ PopulateWithRandomIntegralData<int32>(literal.get());
+ break;
+ case U32:
+ PopulateWithRandomIntegralData<uint32>(literal.get());
+ break;
+ case S64:
+ PopulateWithRandomIntegralData<int64>(literal.get());
+ break;
+ case U64:
+ PopulateWithRandomIntegralData<uint64>(literal.get());
+ break;
+ case PRED: {
+ std::uniform_int_distribution<int> generator(0, 1);
+ std::minstd_rand0 engine;
+ TF_CHECK_OK(literal->Populate<bool>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return generator(engine);
+ }));
+ break;
+ }
+ default:
+ return Unimplemented("Unsupported type for fake literal generation: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ return std::move(literal);
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ const HloModule& module) {
+ std::vector<std::unique_ptr<Literal>> arguments;
+ for (const ShapeLayout& shape_layout :
+ module.config().entry_computation_layout().parameter_layouts()) {
+ TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
+ arguments.push_back(std::move(literal));
+ }
+ return std::move(arguments);
+}
+
+} // namespace xla
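A small usage sketch of the recursive tuple handling above (shapes are illustrative; element types outside the switch fall through to the Unimplemented error):

    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
        {xla::ShapeUtil::MakeShape(xla::F32, {4}),
         xla::ShapeUtil::MakeShape(xla::S64, {2, 2})});
    // Each leaf literal is filled by the matching PopulateWithRandom* helper.
    auto fake = xla::MakeFakeLiteral(tuple_shape).ConsumeValueOrDie();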
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index f3a522b05e..12d5255fce 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -23,12 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace test_utils {
// A class which generates pseudorandom numbers of a given type within a given
// range. Not cryptographically secure and likely not perfectly evenly
@@ -53,63 +53,15 @@ class PseudorandomGenerator {
std::mt19937 generator_;
};
-// Convenience function for creating a rank-2 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR2LiteralWithLayout(
- std::initializer_list<std::initializer_list<NativeT>> values,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
- auto literal = MakeUnique<Literal>();
- const int64 d0 = values.size();
- const int64 d1 = values.begin()->size();
- literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
- *literal->mutable_shape()->mutable_layout() =
- LayoutUtil::MakeLayout(minor_to_major);
- TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto value : inner_list) {
- literal.get()->Set({dim0, dim1}, value);
- ++dim1;
- }
- ++dim0;
- }
- return literal;
-}
+// Generates fake data in a literal of the given shape, or returns an error
+// status if the element type is currently unhandled for fake data generation.
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-// Convenience function for creating a rank-3 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR3LiteralWithLayout(
- std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
- values,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
- auto literal = MakeUnique<Literal>();
- const int64 d0 = values.size();
- const int64 d1 = values.begin()->size();
- const int64 d2 = values.begin()->begin()->size();
- literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
- *literal->mutable_shape()->mutable_layout() =
- LayoutUtil::MakeLayout(minor_to_major);
- TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto inner_inner_list : inner_list) {
- int64 dim2 = 0;
- for (auto value : inner_inner_list) {
- literal.get()->Set({dim0, dim1, dim2}, value);
- ++dim2;
- }
- ++dim1;
- }
- ++dim0;
- }
- return literal;
-}
+// Generates a vector of arguments containing fake data. The number, shape, and
+// layout of the arguments are appropriate for the given HLO module.
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ const HloModule& module);
-} // namespace test_utils
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
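And a sketch of driving the new MakeFakeArguments entry point (the HloModule is assumed to come from a test fixture with its entry computation layout populated, inside a function that returns a Status so TF_ASSIGN_OR_RETURN can propagate errors):

    // One literal is generated per entry-computation parameter, matching the
    // parameter's shape and layout.
    TF_ASSIGN_OR_RETURN(auto arguments, xla::MakeFakeArguments(*module));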
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 759921dce5..091fa0c3ec 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -88,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 89b26b8916..503e7d456e 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"