aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-07-20 15:15:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:22:45 -0700
commit62a10974897c3cdc929a079f389f6770c767377a (patch)
tree5b570d5eaca84424297d2101d4d59b70f2139a1b
parenta3d814cd9100556a6e2e1468f1a9981820b4203c (diff)
[XLA] Make ClientLibraryTestBase::AddParam work with the reference backend.
Previously, AddParam only worked with the "real" backend -- we'd never pass the parameters to the reference backend, so it would always fail. PiperOrigin-RevId: 205461805
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc64
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h6
2 files changed, 55 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index ef784da457..7a2e70d39f 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -273,10 +273,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
+
+ // Transfer and use elements of arguments_, if the AddParam() API was used.
+ std::vector<std::unique_ptr<GlobalData>> owning_arguments;
if (!arguments_.empty()) {
CHECK(arguments.empty());
for (const auto& argument : arguments_) {
- arguments.push_back(argument.get());
+ owning_arguments.push_back(
+ client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))
+ .ValueOrDie());
+ arguments.push_back(owning_arguments.back().get());
}
}
@@ -331,10 +337,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
ErrorSpec error, const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
+
+ // Transfer and use elements of arguments_, if the AddParam() API was used.
+ std::vector<std::unique_ptr<GlobalData>> owning_arguments;
if (!arguments_.empty()) {
CHECK(arguments.empty());
for (const auto& argument : arguments_) {
- arguments.push_back(argument.get());
+ owning_arguments.push_back(
+ client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))
+ .ValueOrDie());
+ arguments.push_back(owning_arguments.back().get());
}
}
@@ -454,6 +466,14 @@ ClientLibraryTestBase::ComputeValueAndReference(
// function.
std::vector<std::unique_ptr<GlobalData>> argument_data;
std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
+
+ // Use `arguments_` if the AddParam() API was used. Otherwise, use
+ // plain `arguments`.
+ if (!arguments_.empty()) {
+ CHECK_EQ(arguments.size(), 0);
+ arguments = arguments_;
+ }
+
for (const auto& arg : arguments) {
TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
@@ -552,10 +572,9 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaBuilder* builder) {
- XlaOp data_handle;
- arguments_.push_back(CreateParameterAndTransferLiteral(
- arguments_.size(), argument, "", builder, &data_handle));
- return data_handle;
+ arguments_.push_back(argument.Clone());
+ return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
+ MaybeConvertShapeToBfloat16(argument.shape()), "");
}
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
@@ -575,22 +594,39 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
nullptr, builder, data_handle);
}
+Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
+ if (!use_bfloat16_) {
+ return shape;
+ }
+ Shape new_shape = shape;
+ ShapeUtil::ForEachMutableSubshape(&new_shape,
+ [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() == F32) {
+ subshape->set_element_type(BF16);
+ }
+ });
+ return new_shape;
+}
+
+Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
+ const Literal& literal) {
+ if (use_bfloat16_) {
+ return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ }
+ return literal.Clone();
+}
+
std::unique_ptr<GlobalData>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
const DeviceHandle* device_handle, XlaBuilder* builder,
XlaOp* data_handle) {
- const Literal* param_literal = &literal;
- std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_) {
- converted_literal = LiteralUtil::ConvertF32ToBF16(literal);
- param_literal = converted_literal.get();
- }
+ Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal, device_handle)
+ client_->TransferToServer(param_literal, device_handle)
.ConsumeValueOrDie();
*data_handle =
- Parameter(builder, parameter_number, param_literal->shape(), name);
+ Parameter(builder, parameter_number, param_literal.shape(), name);
return data;
}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index fcc9347db5..f0f7ff1ea0 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -399,12 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test {
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
+ // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true.
+ Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
+ Shape MaybeConvertShapeToBfloat16(const Shape& shape);
+
// Whether to run tests with all float-type input/output converted to
// bfloat16.
bool use_bfloat16_ = false;
// Arguments to be passed to the computation when it runs.
- std::vector<std::unique_ptr<GlobalData>> arguments_;
+ std::vector<Literal> arguments_;
};
template <typename NativeT>