diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/client.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/client.cc | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 341c02f1a1..c4430dab65 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -132,6 +132,38 @@ Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, return Status::OK(); } +StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed( + const Shape* shape_with_layout, int64 replica_id, + const DeviceHandle* device_handle) { + TransferFromOutfeedRequest request; + if (device_handle) { + *request.mutable_device_handle() = *device_handle; + } + request.set_replica_id(replica_id); + if (shape_with_layout != nullptr) { + *request.mutable_shape_with_layout() = *shape_with_layout; + } + TransferFromOutfeedResponse response; + + VLOG(1) << "making transfer from outfeed request"; + VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferFromOutfeed(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return FailedPrecondition( + "server provided response without a literal in " + "TransferToClient request"); + } + + return WrapUnique(response.release_literal()); +} + Status Client::ResetDevice() { ResetDeviceRequest request; ResetDeviceResponse response; |