diff options
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 8 |
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index d445ced7b0..7c9494f133 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -526,6 +526,15 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle) { const Literal* param_literal = &literal; std::unique_ptr<Literal> converted_literal; if (use_bfloat16_) { @@ -533,7 +542,8 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( param_literal = converted_literal.get(); } std::unique_ptr<GlobalData> data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, param_literal->shape(), name); return data; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 92e16d6de4..a559a653df 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -270,6 +270,14 @@ class ClientLibraryTestBase : public ::testing::Test { int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // As above, but the caller can specify the device that the literal is + // transferred to. If device_handle is nullptr, the literal will be + // transferred to the default device. + std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle); + // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters // or none and no parameters must be passed when invoking the computation if |