aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc12
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h8
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index d445ced7b0..7c9494f133 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -526,6 +526,15 @@ std::unique_ptr<GlobalData>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle) {
+ return CreateParameterAndTransferLiteral(parameter_number, literal, name,
+ nullptr, builder, data_handle);
+}
+
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ const DeviceHandle* device_handle, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle) {
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {
@@ -533,7 +542,8 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
param_literal = converted_literal.get();
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(*param_literal, device_handle)
+ .ConsumeValueOrDie();
*data_handle =
builder->Parameter(parameter_number, param_literal->shape(), name);
return data;
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 92e16d6de4..a559a653df 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -270,6 +270,14 @@ class ClientLibraryTestBase : public ::testing::Test {
int64 parameter_number, const Literal& literal, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle);
+ // As above, but the caller can specify the device that the literal is
+ // transferred to. If device_handle is nullptr, the literal will be
+ // transferred to the default device.
+ std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ const DeviceHandle* device_handle, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle);
+
// Creates a parameter instruction and sets the value that will be passed to
// the computation as specified. This function must be used for all parameters
// or none and no parameters must be passed when invoking the computation if