diff options
author | Justin Lebar <jlebar@google.com> | 2018-07-20 15:15:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 15:22:45 -0700 |
commit | 62a10974897c3cdc929a079f389f6770c767377a (patch) | |
tree | 5b570d5eaca84424297d2101d4d59b70f2139a1b | |
parent | a3d814cd9100556a6e2e1468f1a9981820b4203c (diff) |
[XLA] Make ClientLibraryTestBase::AddParam work with the reference backend.
Previously, AddParam only worked with the "real" backend -- we'd never
pass the parameters to the reference backend, so it would always fail.
PiperOrigin-RevId: 205461805
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 64 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 6 |
2 files changed, 55 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index ef784da457..7a2e70d39f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -273,10 +273,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Shape* shape_with_layout) { std::vector<GlobalData*> arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector<std::unique_ptr<GlobalData>> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -331,10 +337,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ErrorSpec error, const Shape* shape_with_layout) { std::vector<GlobalData*> arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector<std::unique_ptr<GlobalData>> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -454,6 +466,14 @@ ClientLibraryTestBase::ComputeValueAndReference( // function. std::vector<std::unique_ptr<GlobalData>> argument_data; std::vector<std::unique_ptr<GlobalData>> ref_argument_data; + + // Use `arguments_` if the AddParam() API was used. Otherwise, use + // plain `arguments`. + if (!arguments_.empty()) { + CHECK_EQ(arguments.size(), 0); + arguments = arguments_; + } + for (const auto& arg : arguments) { TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); @@ -552,10 +572,9 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { - XlaOp data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; + arguments_.push_back(argument.Clone()); + return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, + MaybeConvertShapeToBfloat16(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, @@ -575,22 +594,39 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, nullptr, builder, data_handle); } +Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { + if (!use_bfloat16_) { + return shape; + } + Shape new_shape = shape; + ShapeUtil::ForEachMutableSubshape(&new_shape, + [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + return new_shape; +} + +Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( + const Literal& literal) { + if (use_bfloat16_) { + return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + } + return literal.Clone(); +} + std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr<Literal> converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } + Literal param_literal = MaybeConvertLiteralToBfloat16(literal); std::unique_ptr<GlobalData> data = - client_->TransferToServer(*param_literal, device_handle) + client_->TransferToServer(param_literal, device_handle) .ConsumeValueOrDie(); *data_handle = - Parameter(builder, parameter_number, param_literal->shape(), name); + Parameter(builder, parameter_number, param_literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index fcc9347db5..f0f7ff1ea0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -399,12 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + Shape MaybeConvertShapeToBfloat16(const Shape& shape); + // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; // Arguments to be passed to the computation when it runs. - std::vector<std::unique_ptr<GlobalData>> arguments_; + std::vector<Literal> arguments_; }; template <typename NativeT> |