aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/client_library_test_base.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.cc')
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc78
1 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index dafd6ebabb..515c0201d1 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -157,7 +158,7 @@ string ClientLibraryTestBase::ExecuteToString(
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -273,10 +274,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
+
+ // Transfer and use elements of arguments_, if the AddParam() API was used.
+ std::vector<std::unique_ptr<GlobalData>> owning_arguments;
if (!arguments_.empty()) {
CHECK(arguments.empty());
for (const auto& argument : arguments_) {
- arguments.push_back(argument.get());
+ owning_arguments.push_back(
+ client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))
+ .ValueOrDie());
+ arguments.push_back(owning_arguments.back().get());
}
}
@@ -295,7 +302,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -331,10 +338,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
ErrorSpec error, const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
+
+ // Transfer and use elements of arguments_, if the AddParam() API was used.
+ std::vector<std::unique_ptr<GlobalData>> owning_arguments;
if (!arguments_.empty()) {
CHECK(arguments.empty());
for (const auto& argument : arguments_) {
- arguments.push_back(argument.get());
+ owning_arguments.push_back(
+ client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))
+ .ValueOrDie());
+ arguments.push_back(owning_arguments.back().get());
}
}
@@ -347,7 +360,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -389,7 +402,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1U8(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
@@ -454,6 +467,14 @@ ClientLibraryTestBase::ComputeValueAndReference(
// function.
std::vector<std::unique_ptr<GlobalData>> argument_data;
std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
+
+ // Use `arguments_` if the AddParam() API was used. Otherwise, use
+ // plain `arguments`.
+ if (!arguments_.empty()) {
+ CHECK_EQ(arguments.size(), 0);
+ arguments = arguments_;
+ }
+
for (const auto& arg : arguments) {
TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
@@ -552,16 +573,16 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaBuilder* builder) {
- XlaOp data_handle;
- arguments_.push_back(CreateParameterAndTransferLiteral(
- arguments_.size(), argument, "", builder, &data_handle));
- return data_handle;
+ arguments_.push_back(argument.Clone());
+ return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
+ MaybeConvertShapeToBfloat16(argument.shape()), "");
}
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
- return ConstantLiteral(
- builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
+ return ConstantLiteral(builder, use_bfloat16_
+ ? *LiteralUtil::ConvertF32ToBF16(literal)
+ : literal);
}
std::unique_ptr<GlobalData>
@@ -574,22 +595,39 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
nullptr, builder, data_handle);
}
+Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
+ if (!use_bfloat16_) {
+ return shape;
+ }
+ Shape new_shape = shape;
+ ShapeUtil::ForEachMutableSubshape(&new_shape,
+ [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() == F32) {
+ subshape->set_element_type(BF16);
+ }
+ });
+ return new_shape;
+}
+
+Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
+ const Literal& literal) {
+ if (use_bfloat16_) {
+ return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ }
+ return literal.Clone();
+}
+
std::unique_ptr<GlobalData>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
const DeviceHandle* device_handle, XlaBuilder* builder,
XlaOp* data_handle) {
- const Literal* param_literal = &literal;
- std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_) {
- converted_literal = Literal::ConvertF32ToBF16(literal);
- param_literal = converted_literal.get();
- }
+ Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal, device_handle)
+ client_->TransferToServer(param_literal, device_handle)
.ConsumeValueOrDie();
*data_handle =
- Parameter(builder, parameter_number, param_literal->shape(), name);
+ Parameter(builder, parameter_number, param_literal.shape(), name);
return data;
}