diff options
author | 2018-05-09 13:07:35 -0700 | |
---|---|---|
committer | 2018-05-09 13:47:51 -0700 | |
commit | e1347ba769b98e260d36e895be2963af35c88d18 (patch) | |
tree | 3f8c5d4edaa71035459f08d9520a4a0fdbcaadf5 /tensorflow/compiler/xla/tests/literal_test_util.cc | |
parent | a4afe20fb4663c0f3b7f1b0086fe1c97557fea7b (diff) |
[XLA] First step in adding Literal slice classes, to improve interface safety
and prepare for enabling more efficient interfacing from Tensor to Literal to
reduce host to device latency.
More specically:
* Introducing a new LiteralBase abstract base class that contains all immutable
methods of from the old Literal class.
* Introducing a subclass LiteralSlice to replace original LiteralView class.
LiteralSlice class is read-only and does not own Shape nor any buffer through
the Pieces. Change a number of callers to use LiteralSlice directly.
* Change Literal class to explicitly own the underlying Shape as well as owning
the underlying buffer via Piece.
* Conversion from Literal to LiteralSlice is now done via an implicit
conversion constructor instead of inheritance.
* Decouple ShapeTree from Literal classes.
* Use copy-and-swap for assignment constructors.
* Other minor cleanups.
PiperOrigin-RevId: 196016576
Diffstat (limited to 'tensorflow/compiler/xla/tests/literal_test_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.cc | 58 |
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index c28f79ae38..868876c72d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -111,7 +111,7 @@ namespace { // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template <typename FromNativeT, typename ToNativeT> -std::unique_ptr<Literal> ConvertType(const Literal& literal) { +std::unique_ptr<Literal> ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -150,12 +150,12 @@ std::unique_ptr<Literal> ConvertType(const Literal& literal) { } // namespace /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { + LiteralSlice literal) { return ConvertType<bfloat16, float>(literal); } /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { + LiteralSlice literal) { return ConvertType<float, bfloat16>(literal); } @@ -237,7 +237,7 @@ template <> // actual literal and compares their values elementwise. Returns true if all // elements are equal. template <typename NativeT> -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, +bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual, tensorflow::gtl::MutableArraySlice<int64> multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { @@ -259,8 +259,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, +/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected, + LiteralSlice actual, const string& message) { EXPECT_TRUE(Equal(expected, actual)) << "expected:\n" @@ -269,13 +269,13 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, << (message.empty() ? "" : StrCat("\nmessage: ", message)); } -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { +/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected, + LiteralSlice actual) { EXPECT_FALSE(Equal(expected, actual)); } /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { + LiteralSlice expected, LiteralSlice actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; @@ -324,9 +324,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, SCOPED_TRACE(StrCat("Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); + // Create LiteralSlices of the expected and actual elements. + auto result = + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})); tuple_match = tuple_match ? !!result : false; } match = tuple_match; @@ -368,7 +368,7 @@ int64 RecursiveElementCount(const Shape& shape) { // 3 minutes. The utility of printing a literal with >1000 elements is // questionable, especially when writing the Literal proto to disk is orders // of magnitude faster. -string TruncateHugeLiteral(const Literal& literal) { +string TruncateHugeLiteral(LiteralSlice literal) { return RecursiveElementCount(literal.shape()) < 1000 ? literal.ToString() : "[TRUNCATED, Literal with more than 1000 values]"; @@ -435,8 +435,8 @@ class NearComparator { // result. The assertion result is successful if all actual and expected // elements are within the given error bound. In case of error, the assertion // result contains a detailed error message in case of failure. - static ::testing::AssertionResult Compare(const Literal& expected, - const Literal& actual, + static ::testing::AssertionResult Compare(LiteralSlice expected, + LiteralSlice actual, ErrorSpec error, bool detailed_message) { NearComparator<NativeT> comparator(expected, actual, error, @@ -472,7 +472,7 @@ class NearComparator { } }; - explicit NearComparator(const Literal& expected, const Literal& actual, + explicit NearComparator(LiteralSlice expected, LiteralSlice actual, ErrorSpec error, bool detailed_message) : expected_(expected), actual_(actual), @@ -649,7 +649,7 @@ class NearComparator { } // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { + void WriteLiteralToTempFile(LiteralSlice literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), @@ -733,8 +733,8 @@ class NearComparator { } // 'actual' and 'expected' literals being compared. - const Literal& expected_; - const Literal& actual_; + LiteralSlice expected_; + LiteralSlice actual_; // The error bounds of the comparison. ErrorSpec error_; @@ -794,8 +794,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; // Helper function for comparing two literals for nearness. Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. -::testing::AssertionResult NearHelper(const Literal& expected, - const Literal& actual, +::testing::AssertionResult NearHelper(LiteralSlice expected, + LiteralSlice actual, const ErrorSpec& error, bool detailed_message, const ShapeIndex& shape_index) { @@ -807,8 +807,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; if (ShapeUtil::IsTuple(expected.shape())) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); ::testing::AssertionResult res = @@ -874,14 +874,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, bool detailed_message) { return NearHelper(expected, actual, error, detailed_message, /*shape_index=*/{}); } -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, +/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected, + LiteralSlice actual, const ErrorSpec& error, const string& message) { ::testing::AssertionResult res = @@ -897,7 +897,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; } /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional<ErrorSpec>& error) { if (error.has_value()) { VLOG(1) << "Expects near"; @@ -908,7 +908,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; } /*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional<ErrorSpec>& error) { EXPECT_TRUE(NearOrEqual(expected, actual, error)); } @@ -920,7 +920,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds; /* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape( tensorflow::gtl::ArraySlice<int64> new_dimensions, - tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) { + tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; |