From fdff4048d4d0fdf7c12f927b92bb5e2fb812df12 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 15 Nov 2017 14:07:41 -0800 Subject: Add `WorkerService.DeleteWorkerSession` method to fix a memory leak. The new method is the counterpart to `WorkerService.CreateWorkerSession`, and is called in all cases where worker sessions have been explicitly created (i.e. when using ClusterSpec propagation). PiperOrigin-RevId: 175877407 --- .../core/distributed_runtime/master_session.cc | 60 ++++++++++++++++++++++ .../core/distributed_runtime/master_session.h | 4 ++ .../distributed_runtime/rpc/grpc_remote_worker.cc | 8 +++ .../distributed_runtime/rpc/grpc_worker_service.cc | 11 ++++ .../rpc/grpc_worker_service_impl.cc | 2 + .../rpc/grpc_worker_service_impl.h | 1 + tensorflow/core/distributed_runtime/worker.cc | 7 +++ tensorflow/core/distributed_runtime/worker.h | 4 ++ .../core/distributed_runtime/worker_interface.h | 9 ++++ tensorflow/core/protobuf/worker.proto | 16 ++++++ tensorflow/core/protobuf/worker_service.proto | 4 ++ 11 files changed, 126 insertions(+) diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5798ad09e8..91a1fa7d1e 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1044,6 +1044,7 @@ Status MasterSession::Create(GraphDef* graph_def, graph_def, execution_options, &execution_state_)); } if (options.cluster_def != nullptr) { + should_delete_worker_sessions_ = true; return CreateWorkerSessions(options); } return Status::OK(); @@ -1122,6 +1123,59 @@ Status MasterSession::CreateWorkerSessions( return status; } +Status MasterSession::DeleteWorkerSessions() { + WorkerCacheInterface* worker_cache = get_worker_cache(); + std::vector worker_names; + worker_cache->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. + DeleteWorkerSessionRequest request; + DeleteWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. + for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::ListDevices(ListDevicesResponse* resp) const { if (worker_cache_) { // This is a ClusterSpec-propagated session, and thus env_->local_devices @@ -1604,6 +1658,12 @@ Status MasterSession::Close() { ClearRunsTable(&to_unref, &partial_run_graphs_); } for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); + if (should_delete_worker_sessions_) { + Status s = DeleteWorkerSessions(); + if (!s.ok()) { + LOG(WARNING) << s; + } + } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index eb696eb06a..4bd4e1367a 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -201,6 +201,10 @@ class MasterSession : public core::RefCounted { // workers. Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + // TODO(b/36574172): Always use Create/DeleteWorkerSession. + bool should_delete_worker_sessions_ = false; + Status DeleteWorkerSessions(); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 170c72deca..b3b05408b1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -47,6 +47,7 @@ class GrpcRemoteWorker : public WorkerInterface { cq_(completion_queue), getstatus_(Method(GrpcWorkerMethod::kGetStatus)), createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), + deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)), registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)), deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)), rungraph_(Method(GrpcWorkerMethod::kRunGraph)), @@ -71,6 +72,12 @@ class GrpcRemoteWorker : public WorkerInterface { IssueRequest(request, response, createworkersession_, std::move(done)); } + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override { + IssueRequest(request, response, deleteworkersession_, std::move(done)); + } + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override { @@ -199,6 +206,7 @@ class GrpcRemoteWorker : public WorkerInterface { const ::grpc::string getstatus_; const ::grpc::string createworkersession_; + const ::grpc::string deleteworkersession_; const ::grpc::string registergraph_; const ::grpc::string deregistergraph_; const ::grpc::string rungraph_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 4ee5ae0901..eee93ec657 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -114,6 +114,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // types. ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(CreateWorkerSession, false); + ENQUEUE_REQUEST(DeleteWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -192,6 +193,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(CreateWorkerSession, false); } + void DeleteWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->DeleteWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(DeleteWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 348c6dc98b..05a9db10d3 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -32,6 +32,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/GetStatus"; case GrpcWorkerMethod::kCreateWorkerSession: return "/tensorflow.WorkerService/CreateWorkerSession"; + case GrpcWorkerMethod::kDeleteWorkerSession: + return "/tensorflow.WorkerService/DeleteWorkerSession"; case GrpcWorkerMethod::kRegisterGraph: return "/tensorflow.WorkerService/RegisterGraph"; case GrpcWorkerMethod::kDeregisterGraph: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index e9862a61a3..fb23f8631f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -110,6 +110,7 @@ namespace tensorflow { enum class GrpcWorkerMethod { kGetStatus, kCreateWorkerSession, + kDeleteWorkerSession, kRegisterGraph, kDeregisterGraph, kRunGraph, diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index fcb1830197..8bf87923ed 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -48,6 +48,13 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, done(s); } +void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) { + Status s = env_->session_mgr->DeleteSession(request->session_handle()); + done(s); +} + void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) { diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 07300338c3..c62347926f 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -52,6 +52,10 @@ class Worker : public WorkerInterface { CreateWorkerSessionResponse* response, StatusCallback done) override; + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override; + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index c9db28ec67..4c58bf41a4 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -44,6 +44,10 @@ class WorkerInterface { const CreateWorkerSessionRequest* request, CreateWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void DeleteWorkerSessionAsync( + const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) = 0; @@ -118,6 +122,11 @@ class WorkerInterface { return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); } + Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response) { + return CallAndWait(&ME::DeleteWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 34a5cff366..e7b3f36fcc 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -64,6 +64,22 @@ message CreateWorkerSessionRequest { message CreateWorkerSessionResponse { } +//////////////////////////////////////////////////////////////////////////////// +// +// DeleteSession method request/response messages +// +// Deletes all worker-side state associated with the given session handle. +// +//////////////////////////////////////////////////////////////////////////////// + +message DeleteWorkerSessionRequest { + // Sessions are identified by a given handle. + string session_handle = 1; +} + +message DeleteWorkerSessionResponse { +} + //////////////////////////////////////////////////////////////////////////////// // // RegisterGraph method request/response messages diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 3de9e48b78..e1bfb04d7c 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -43,6 +43,10 @@ service WorkerService { rpc CreateWorkerSession(CreateWorkerSessionRequest) returns (CreateWorkerSessionResponse); + // See worker.proto for details. + rpc DeleteWorkerSession(DeleteWorkerSessionRequest) + returns (DeleteWorkerSessionResponse); + // See worker.proto for details. rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse); -- cgit v1.2.3