aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-11-15 14:07:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-15 14:12:16 -0800
commitfdff4048d4d0fdf7c12f927b92bb5e2fb812df12 (patch)
tree806544c485973a3f1e2d8e09eb6b9e4094e53060
parentb0bcf675a4b5d6217f3b58fd27b344f20e7bf25d (diff)
Add `WorkerService.DeleteWorkerSession` method to fix a memory leak.
The new method is the counterpart to `WorkerService.CreateWorkerSession`, and is called in all cases where worker sessions have been explicitly created (i.e. when using ClusterSpec propagation). PiperOrigin-RevId: 175877407
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc60
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc8
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc11
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h1
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc7
-rw-r--r--tensorflow/core/distributed_runtime/worker.h4
-rw-r--r--tensorflow/core/distributed_runtime/worker_interface.h9
-rw-r--r--tensorflow/core/protobuf/worker.proto16
-rw-r--r--tensorflow/core/protobuf/worker_service.proto4
11 files changed, 126 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 5798ad09e8..91a1fa7d1e 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1044,6 +1044,7 @@ Status MasterSession::Create(GraphDef* graph_def,
graph_def, execution_options, &execution_state_));
}
if (options.cluster_def != nullptr) {
+ should_delete_worker_sessions_ = true;
return CreateWorkerSessions(options);
}
return Status::OK();
@@ -1122,6 +1123,59 @@ Status MasterSession::CreateWorkerSessions(
return status;
}
+// Asks every worker listed by this session's worker cache to delete the
+// worker-side session identified by `handle_`, and blocks until every worker
+// has invoked its completion callback.
+//
+// Returns the per-worker statuses folded together with Status::Update().
+Status MasterSession::DeleteWorkerSessions() {
+  // Use one cache for listing, creating and releasing workers so that every
+  // worker is released through the same cache that created it.
+  WorkerCacheInterface* worker_cache = get_worker_cache();
+  std::vector<string> worker_names;
+  worker_cache->ListWorkers(&worker_names);
+
+  struct WorkerGroup {
+    // The worker name. (Not owned.)
+    const string* name;
+
+    // The worker referenced by name. (Not owned.)
+    WorkerInterface* worker = nullptr;
+
+    // Request and response used for this worker's RPC.
+    DeleteWorkerSessionRequest request;
+    DeleteWorkerSessionResponse response;
+    Status status = Status::OK();
+  };
+  BlockingCounter done(worker_names.size());
+  std::vector<WorkerGroup> workers(worker_names.size());
+
+  // Release every worker that was created, on every exit path.
+  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
+    for (auto&& worker_group : workers) {
+      if (worker_group.worker != nullptr) {
+        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
+      }
+    }
+  });
+
+  // Create all the workers and fill in their requests before issuing any RPC.
+  for (size_t i = 0; i < worker_names.size(); ++i) {
+    workers[i].name = &worker_names[i];
+    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
+    workers[i].request.set_session_handle(handle_);
+  }
+
+  // Kick off the deletions. The callbacks capture `workers` and `done` by
+  // reference; that is safe because done.Wait() below does not return until
+  // every callback has run.
+  // NOTE(review): the result of CreateWorker() is dereferenced without a null
+  // check. If CreateWorker() can return nullptr for a listed name, this would
+  // crash -- confirm against the WorkerCacheInterface contract.
+  for (size_t i = 0; i < worker_names.size(); ++i) {
+    auto cb = [i, &workers, &done](const Status& s) {
+      workers[i].status = s;
+      done.DecrementCount();
+    };
+    workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request,
+                                                &workers[i].response, cb);
+  }
+
+  done.Wait();
+
+  // Fold the per-worker results into a single status.
+  Status status = Status::OK();
+  for (const WorkerGroup& worker_group : workers) {
+    status.Update(worker_group.status);
+  }
+  return status;
+}
+
+
Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
if (worker_cache_) {
// This is a ClusterSpec-propagated session, and thus env_->local_devices
@@ -1604,6 +1658,12 @@ Status MasterSession::Close() {
ClearRunsTable(&to_unref, &partial_run_graphs_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
+ if (should_delete_worker_sessions_) {
+ Status s = DeleteWorkerSessions();
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index eb696eb06a..4bd4e1367a 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -201,6 +201,10 @@ class MasterSession : public core::RefCounted {
// workers.
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
+ // TODO(b/36574172): Always use Create/DeleteWorkerSession.
+ bool should_delete_worker_sessions_ = false;
+ Status DeleteWorkerSessions();
+
Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 170c72deca..b3b05408b1 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -47,6 +47,7 @@ class GrpcRemoteWorker : public WorkerInterface {
cq_(completion_queue),
getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
+ deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
@@ -71,6 +72,12 @@ class GrpcRemoteWorker : public WorkerInterface {
IssueRequest(request, response, createworkersession_, std::move(done));
}
+ // Forwards `request`, `response` and `done` to IssueRequest() using the
+ // DeleteWorkerSession method name held in `deleteworkersession_`.
+ void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response,
+ StatusCallback done) override {
+ IssueRequest(request, response, deleteworkersession_, std::move(done));
+ }
+
void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) override {
@@ -199,6 +206,7 @@ class GrpcRemoteWorker : public WorkerInterface {
const ::grpc::string getstatus_;
const ::grpc::string createworkersession_;
+ const ::grpc::string deleteworkersession_;
const ::grpc::string registergraph_;
const ::grpc::string deregistergraph_;
const ::grpc::string rungraph_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 4ee5ae0901..eee93ec657 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -114,6 +114,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// types.
ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CreateWorkerSession, false);
+ ENQUEUE_REQUEST(DeleteWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false);
@@ -192,6 +193,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CreateWorkerSession, false);
}
+ // Handles one incoming DeleteWorkerSession call: schedules a closure that
+ // runs worker_->DeleteWorkerSession() on the call's request/response and
+ // sends the resulting status back, converted with ToGrpcStatus().
+ void DeleteWorkerSessionHandler(
+ WorkerCall<DeleteWorkerSessionRequest, DeleteWorkerSessionResponse>*
+ call) {
+ Schedule([this, call]() {
+ Status s = worker_->DeleteWorkerSession(&call->request, &call->response);
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ // NOTE(review): presumably re-arms the service to accept the next
+ // DeleteWorkerSession call -- confirm against the ENQUEUE_REQUEST macro.
+ ENQUEUE_REQUEST(DeleteWorkerSession, false);
+ }
+
void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
index 348c6dc98b..05a9db10d3 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -32,6 +32,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
return "/tensorflow.WorkerService/GetStatus";
case GrpcWorkerMethod::kCreateWorkerSession:
return "/tensorflow.WorkerService/CreateWorkerSession";
+ case GrpcWorkerMethod::kDeleteWorkerSession:
+ return "/tensorflow.WorkerService/DeleteWorkerSession";
case GrpcWorkerMethod::kRegisterGraph:
return "/tensorflow.WorkerService/RegisterGraph";
case GrpcWorkerMethod::kDeregisterGraph:
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index e9862a61a3..fb23f8631f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -110,6 +110,7 @@ namespace tensorflow {
enum class GrpcWorkerMethod {
kGetStatus,
kCreateWorkerSession,
+ kDeleteWorkerSession,
kRegisterGraph,
kDeregisterGraph,
kRunGraph,
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index fcb1830197..8bf87923ed 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -48,6 +48,13 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
done(s);
}
+// Deletes the session named by request->session_handle() through
+// env_->session_mgr->DeleteSession() and passes the resulting status to
+// `done`. `done` is invoked before this function returns; `response` is not
+// written to.
+void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response,
+ StatusCallback done) {
+ Status s = env_->session_mgr->DeleteSession(request->session_handle());
+ done(s);
+}
+
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) {
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index 07300338c3..c62347926f 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -52,6 +52,10 @@ class Worker : public WorkerInterface {
CreateWorkerSessionResponse* response,
StatusCallback done) override;
+ void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response,
+ StatusCallback done) override;
+
void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) override;
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index c9db28ec67..4c58bf41a4 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -44,6 +44,10 @@ class WorkerInterface {
const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response, StatusCallback done) = 0;
+ virtual void DeleteWorkerSessionAsync(
+ const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
+
virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) = 0;
@@ -118,6 +122,11 @@ class WorkerInterface {
return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
}
+ // Status-returning wrapper around DeleteWorkerSessionAsync(): forwards the
+ // request and response to CallAndWait() and returns its result.
+ // NOTE(review): presumably blocks until the async call completes -- confirm
+ // against CallAndWait().
+ Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response) {
+ return CallAndWait(&ME::DeleteWorkerSessionAsync, request, response);
+ }
+
Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response);
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 34a5cff366..e7b3f36fcc 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -66,6 +66,22 @@ message CreateWorkerSessionResponse {
////////////////////////////////////////////////////////////////////////////////
//
+// DeleteWorkerSession method request/response messages
+//
+// Deletes all worker-side state associated with the given session handle.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// Request message for the WorkerService.DeleteWorkerSession RPC.
+message DeleteWorkerSessionRequest {
+ // Sessions are identified by a given handle.
+ string session_handle = 1;
+}
+
+// Response message for the WorkerService.DeleteWorkerSession RPC. It declares
+// no fields.
+message DeleteWorkerSessionResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
// RegisterGraph method request/response messages
//
// For each session, after the master placed every node on a device,
diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto
index 3de9e48b78..e1bfb04d7c 100644
--- a/tensorflow/core/protobuf/worker_service.proto
+++ b/tensorflow/core/protobuf/worker_service.proto
@@ -44,6 +44,10 @@ service WorkerService {
returns (CreateWorkerSessionResponse);
// See worker.proto for details.
+ rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
+ returns (DeleteWorkerSessionResponse);
+
+ // See worker.proto for details.
rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
// See worker.proto for details.