aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/client_library_test_base.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.cc')
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc71
1 files changed, 35 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 8a236db0ff..fbdf0fcb65 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
- return result.ValueOrDie()->ToString();
+ return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
- verify_output(*actual, "");
+ verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual,
+ verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
- if (ShapeUtil::IsTuple(literal->shape())) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal->shape()));
+ ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
- std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+ std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
- literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+ literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+ ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
- client_->TransferToServer(*literal_relayout));
+ client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
- verify_output(*actual, error_message);
+ verify_output(actual, error_message);
return Status::OK();
};
@@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
@@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
@@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+ Literal expected_literal = LiteralUtil::CreateR1U8(expected);
- VLOG(1) << "expected: " << expected_literal->ToString();
- VLOG(1) << "actual: " << actual->ToString();
+ VLOG(1) << "expected: " << expected_literal.ToString();
+ VLOG(1) << "actual: " << actual.ToString();
- EXPECT_EQ(expected, actual->GetR1U8AsString());
+ EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
- ? *LiteralUtil::ConvertF32ToBF16(literal)
- : literal);
+ ? LiteralUtil::ConvertF32ToBF16(literal)
+ : LiteralSlice(literal));
}
std::unique_ptr<GlobalData>
@@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
- return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}