diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/remote_execute_node.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/remote_execute_node.h | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 28b68c3b88..0e3a68c4d8 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -29,8 +29,8 @@ namespace eager { class RemoteExecuteNode : public tensorflow::EagerNode { public: RemoteExecuteNode( - tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client, + tensorflow::uint64 id, std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client, const gtl::InlinedVector<TensorHandle*, 4>& inputs, std::function<void(const Status& status, const EnqueueResponse& response)> done_callback) @@ -45,8 +45,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } RemoteExecuteNode(tensorflow::uint64 id, - const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client) + std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client) : tensorflow::EagerNode(id), request_(std::move(request)), eager_client_(eager_client) {} @@ -58,10 +58,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } tensorflow::Status Run() override { - tensorflow::eager::EnqueueResponse response; - tensorflow::Status status; + EnqueueResponse response; + Status status; Notification n; - eager_client_->EnqueueAsync(&request_, &response, + eager_client_->EnqueueAsync(request_.get(), &response, [&n, &status](const tensorflow::Status& s) { status.Update(s); n.Notify(); @@ -76,9 +76,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } private: - EnqueueRequest request_; - tensorflow::eager::EagerClient* - eager_client_; // Not owned, and must outlive the RemoteExecuteNode. + std::unique_ptr<EnqueueRequest> request_; + EagerClient* eager_client_; // Not owned, and must outlive this node. // This is required to ensure that the tensor handles stay alive across the // execution. |