aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc42
1 files changed, 20 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index e692b8c5d5..091a5d2cac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -38,7 +38,7 @@ namespace {
class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
protected:
// Sends the literal to the server and retrieves it back.
- std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ Literal RoundTripToServer(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
return client_->Transfer(*data).ConsumeValueOrDie();
@@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
+ Literal actual =
reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0, actual->Get<float>({0}));
- EXPECT_EQ(24.0, actual->Get<float>({1}));
+ EXPECT_EQ(42.0, actual.Get<float>({0}));
+ EXPECT_EQ(24.0, actual.Get<float>({1}));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
@@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
} // namespace