diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/round_trip_transfer_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/round_trip_transfer_test.cc | 51 |
1 files changed, 24 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac..cd5a531603 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr<GlobalData> data = client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr<Literal> result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0<int32>(42)); + RoundTripTest(LiteralUtil::CreateR0<int32>(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0<float>(42.0)); + RoundTripTest(LiteralUtil::CreateR0<float>(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1<float>({})); + RoundTripTest(LiteralUtil::CreateR1<float>({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector<float> values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1<float>(values)); + RoundTripTest(LiteralUtil::CreateR1<float>(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector<float> values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1<float>(values)); + RoundTripTest(LiteralUtil::CreateR1<float>(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector<float> values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1<float>(values)); + RoundTripTest(LiteralUtil::CreateR1<float>(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector<float> values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1<float>(values)); + RoundTripTest(LiteralUtil::CreateR1<float>(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4<float>({{ + RoundTripTest(LiteralUtil::CreateR4<float>({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(), - LiteralUtil::CreateR1<float>({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}), + LiteralUtil::CreateR1<float>({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(), - LiteralUtil::CreateR1<float>({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}), + LiteralUtil::CreateR1<float>({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(), - LiteralUtil::CreateR1<int>({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D<float> array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d)); } } // namespace |