aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/client.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/client.cc')
-rw-r--r--tensorflow/compiler/xla/client/client.cc32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 341c02f1a1..c4430dab65 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -132,6 +132,38 @@ Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
return Status::OK();
}
+StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+ const Shape* shape_with_layout, int64 replica_id,
+ const DeviceHandle* device_handle) {
+ TransferFromOutfeedRequest request;
+ if (device_handle) {
+ *request.mutable_device_handle() = *device_handle;
+ }
+ request.set_replica_id(replica_id);
+ if (shape_with_layout != nullptr) {
+ *request.mutable_shape_with_layout() = *shape_with_layout;
+ }
+ TransferFromOutfeedResponse response;
+
+ VLOG(1) << "making transfer from outfeed request";
+ VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
+ Status s = stub_->TransferFromOutfeed(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";
+
+ if (!response.has_literal()) {
+ return FailedPrecondition(
+ "server provided response without a literal in "
+ "TransferToClient request");
+ }
+
+ return WrapUnique(response.release_literal());
+}
+
Status Client::ResetDevice() {
ResetDeviceRequest request;
ResetDeviceResponse response;