diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/client_test.cc | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf48..6f2ca84bb6 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr<GlobalData> data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr<Literal> expected_literal = - LiteralUtil::CreateR2WithLayout<int32>( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace |