aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/literal_test_util.cc
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2018-05-09 13:07:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 13:47:51 -0700
commite1347ba769b98e260d36e895be2963af35c88d18 (patch)
tree3f8c5d4edaa71035459f08d9520a4a0fdbcaadf5 /tensorflow/compiler/xla/tests/literal_test_util.cc
parenta4afe20fb4663c0f3b7f1b0086fe1c97557fea7b (diff)
[XLA] First step in adding Literal slice classes, to improve interface safety
and prepare for enabling more efficient interfacing from Tensor to Literal to reduce host to device latency. More specically: * Introducing a new LiteralBase abstract base class that contains all immutable methods of from the old Literal class. * Introducing a subclass LiteralSlice to replace original LiteralView class. LiteralSlice class is read-only and does not own Shape nor any buffer through the Pieces. Change a number of callers to use LiteralSlice directly. * Change Literal class to explicitly own the underlying Shape as well as owning the underlying buffer via Piece. * Conversion from Literal to LiteralSlice is now done via an implicit conversion constructor instead of inheritance. * Decouple ShapeTree from Literal classes. * Use copy-and-swap for assignment constructors. * Other minor cleanups. PiperOrigin-RevId: 196016576
Diffstat (limited to 'tensorflow/compiler/xla/tests/literal_test_util.cc')
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc58
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index c28f79ae38..868876c72d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -111,7 +111,7 @@ namespace {
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(const Literal& literal) {
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@@ -150,12 +150,12 @@ std::unique_ptr<Literal> ConvertType(const Literal& literal) {
} // namespace
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
- const Literal& literal) {
+ LiteralSlice literal) {
return ConvertType<bfloat16, float>(literal);
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
- const Literal& literal) {
+ LiteralSlice literal) {
return ConvertType<float, bfloat16>(literal);
}
@@ -237,7 +237,7 @@ template <>
// actual literal and compares their values elementwise. Returns true if all
// elements are equal.
template <typename NativeT>
-bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
tensorflow::gtl::MutableArraySlice<int64> multi_index,
int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
@@ -259,8 +259,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
} // namespace
-/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
- const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
+ LiteralSlice actual,
const string& message) {
EXPECT_TRUE(Equal(expected, actual))
<< "expected:\n"
@@ -269,13 +269,13 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
<< (message.empty() ? "" : StrCat("\nmessage: ", message));
}
-/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
- const Literal& actual) {
+/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
+ LiteralSlice actual) {
EXPECT_FALSE(Equal(expected, actual));
}
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
- const Literal& expected, const Literal& actual) {
+ LiteralSlice expected, LiteralSlice actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
@@ -324,9 +324,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
ShapeUtil::HumanString(expected.shape())));
- // Create LiteralViews of the expected and actual elements.
- auto result = Equal(LiteralView::Create(expected, {i}),
- LiteralView::Create(actual, {i}));
+ // Create LiteralSlices of the expected and actual elements.
+ auto result =
+ Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
tuple_match = tuple_match ? !!result : false;
}
match = tuple_match;
@@ -368,7 +368,7 @@ int64 RecursiveElementCount(const Shape& shape) {
// 3 minutes. The utility of printing a literal with >1000 elements is
// questionable, especially when writing the Literal proto to disk is orders
// of magnitude faster.
-string TruncateHugeLiteral(const Literal& literal) {
+string TruncateHugeLiteral(LiteralSlice literal) {
return RecursiveElementCount(literal.shape()) < 1000
? literal.ToString()
: "[TRUNCATED, Literal with more than 1000 values]";
@@ -435,8 +435,8 @@ class NearComparator {
// result. The assertion result is successful if all actual and expected
// elements are within the given error bound. In case of error, the assertion
// result contains a detailed error message in case of failure.
- static ::testing::AssertionResult Compare(const Literal& expected,
- const Literal& actual,
+ static ::testing::AssertionResult Compare(LiteralSlice expected,
+ LiteralSlice actual,
ErrorSpec error,
bool detailed_message) {
NearComparator<NativeT> comparator(expected, actual, error,
@@ -472,7 +472,7 @@ class NearComparator {
}
};
- explicit NearComparator(const Literal& expected, const Literal& actual,
+ explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
ErrorSpec error, bool detailed_message)
: expected_(expected),
actual_(actual),
@@ -649,7 +649,7 @@ class NearComparator {
}
// Writes the given literal to a file in the test temporary directory.
- void WriteLiteralToTempFile(const Literal& literal, const string& name) {
+ void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
int64 now_usec = tensorflow::Env::Default()->NowMicros();
string filename = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(),
@@ -733,8 +733,8 @@ class NearComparator {
}
// 'actual' and 'expected' literals being compared.
- const Literal& expected_;
- const Literal& actual_;
+ LiteralSlice expected_;
+ LiteralSlice actual_;
// The error bounds of the comparison.
ErrorSpec error_;
@@ -794,8 +794,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
-::testing::AssertionResult NearHelper(const Literal& expected,
- const Literal& actual,
+::testing::AssertionResult NearHelper(LiteralSlice expected,
+ LiteralSlice actual,
const ErrorSpec& error,
bool detailed_message,
const ShapeIndex& shape_index) {
@@ -807,8 +807,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
if (ShapeUtil::IsTuple(expected.shape())) {
for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- const auto expected_element = LiteralView::Create(expected, {i});
- const auto actual_element = LiteralView::Create(actual, {i});
+ const auto expected_element = LiteralSlice(expected, {i});
+ const auto actual_element = LiteralSlice(actual, {i});
ShapeIndex element_index = shape_index;
element_index.push_back(i);
::testing::AssertionResult res =
@@ -874,14 +874,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
- const Literal& expected, const Literal& actual, const ErrorSpec& error,
+ LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
bool detailed_message) {
return NearHelper(expected, actual, error, detailed_message,
/*shape_index=*/{});
}
-/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
- const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
+ LiteralSlice actual,
const ErrorSpec& error,
const string& message) {
::testing::AssertionResult res =
@@ -897,7 +897,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
}
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
if (error.has_value()) {
VLOG(1) << "Expects near";
@@ -908,7 +908,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
}
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
EXPECT_TRUE(NearOrEqual(expected, actual, error));
}
@@ -920,7 +920,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
+ tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];