aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/remote_execute_node.h')
-rw-r--r--tensorflow/core/distributed_runtime/eager/remote_execute_node.h19
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 28b68c3b88..0e3a68c4d8 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -29,8 +29,8 @@ namespace eager {
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
RemoteExecuteNode(
- tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
- tensorflow::eager::EagerClient* eager_client,
+ tensorflow::uint64 id, std::unique_ptr<EnqueueRequest> request,
+ EagerClient* eager_client,
const gtl::InlinedVector<TensorHandle*, 4>& inputs,
std::function<void(const Status& status, const EnqueueResponse& response)>
done_callback)
@@ -45,8 +45,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
RemoteExecuteNode(tensorflow::uint64 id,
- const tensorflow::eager::EnqueueRequest& request,
- tensorflow::eager::EagerClient* eager_client)
+ std::unique_ptr<EnqueueRequest> request,
+ EagerClient* eager_client)
: tensorflow::EagerNode(id),
request_(std::move(request)),
eager_client_(eager_client) {}
@@ -58,10 +58,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
tensorflow::Status Run() override {
- tensorflow::eager::EnqueueResponse response;
- tensorflow::Status status;
+ EnqueueResponse response;
+ Status status;
Notification n;
- eager_client_->EnqueueAsync(&request_, &response,
+ eager_client_->EnqueueAsync(request_.get(), &response,
[&n, &status](const tensorflow::Status& s) {
status.Update(s);
n.Notify();
@@ -76,9 +76,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
}
private:
- EnqueueRequest request_;
- tensorflow::eager::EagerClient*
- eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+ std::unique_ptr<EnqueueRequest> request_;
+ EagerClient* eager_client_; // Not owned, and must outlive this node.
// This is required to ensure that the tensor handles stay alive across the
// execution.