diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 8a236db0ff..fbdf0fcb65 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } -StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, absl::Span<GlobalData* const> arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, absl::Span<GlobalData* const> arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. @@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr<std::unique_ptr<Literal>> -ClientLibraryTestBase::ExecuteAndTransferReference( +StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference( const XlaComputation& computation, absl::Span<GlobalData* const> arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, absl::Span<GlobalData* const> arguments) { - std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const string& error_message)>& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, + verify_output(actual, absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); @@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr<Literal> converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } @@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr<Literal> converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } @@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr<Literal> reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr<Literal> reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> +StatusOr<std::pair<Literal, Literal>> ClientLibraryTestBase::ComputeValueAndReference( XlaBuilder* builder, absl::Span<const Literal> arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's @@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr<GlobalData> @@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } |