diff options
author | Derek Murray <mrry@google.com> | 2018-04-19 10:57:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 11:00:55 -0700 |
commit | f196351cd4e21ed6c17dcf544e0fa6cfa3030b4e (patch) | |
tree | 740b087e8580ba8332987a994a88b7052abb2c4e /tensorflow/core/distributed_runtime | |
parent | 6a7779f3384e48012d3e27ae0f48d410f5174d06 (diff) |
Allow non-isolated worker sessions to borrow `WorkerEnv::device_mgr`.
Without this change, a shared resource (e.g. an Iterator) could not be
created in one session `s1`, and used in a later session `s2` after
`s1` was closed, because the iterator might indirectly capture devices
from the previous session, and use them after they are freed when the
`WorkerSession` was deleted.
The current change only affects the singleton "legacy" WorkerSession,
which is never deleted, but this is necessary to switch all sessions
to use separate WorkerSession objects.
PiperOrigin-RevId: 193541426
Diffstat (limited to 'tensorflow/core/distributed_runtime')
8 files changed, 104 insertions, 34 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index b07cb8cdcb..d564727da5 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -133,6 +133,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:ptr_util", "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index bafd9bfc68..5f6931e008 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -253,13 +253,13 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( WorkerSession* sess = session(); Device* src_device; - Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device); + Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device); if (!s.ok()) { done(s); return; } Device* dst_device; - s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 067dc5dff5..b8cb538503 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -227,7 +227,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( Device* dst_device; if (s.ok()) { - s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { if (rwi != nullptr) { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index e51d63cf2b..357e9f8930 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -33,11 +34,11 @@ SessionMgr::SessionMgr( WorkerCacheFactory worker_cache_factory) : worker_env_(worker_env), default_worker_cache_(std::move(default_worker_cache)), - legacy_session_(new WorkerSession( + legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr( "", default_worker_name, std::unique_ptr<WorkerCacheInterface>( new WorkerCacheWrapper(default_worker_cache_.get())), - std::unique_ptr<DeviceMgr>(worker_env->device_mgr), + worker_env->device_mgr, std::unique_ptr<GraphMgr>( new GraphMgr(worker_env, worker_env->device_mgr)))), worker_cache_factory_(std::move(worker_cache_factory)) {} @@ -71,19 +72,32 @@ Status SessionMgr::CreateSession(const string& session, CHECK(!worker_env_->local_devices.empty()) << "The WorkerEnv must have at least one device in `local_devices`."; - std::vector<Device*> renamed_devices; - for (Device* d : worker_env_->local_devices) { - renamed_devices.push_back(RenamedDevice::NewRenamedDevice( - worker_name, d, false, isolate_session_state)); - } - std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices)); + std::shared_ptr<WorkerSession> worker_session; - std::unique_ptr<GraphMgr> graph_mgr( - new GraphMgr(worker_env_, device_mgr.get())); + if (isolate_session_state) { + // Create a private copy of the DeviceMgr for the WorkerSession. + std::vector<Device*> renamed_devices; + for (Device* d : worker_env_->local_devices) { + renamed_devices.push_back(RenamedDevice::NewRenamedDevice( + worker_name, d, false, isolate_session_state)); + } - std::shared_ptr<WorkerSession> worker_session(new WorkerSession( - session, worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache), - std::move(device_mgr), std::move(graph_mgr))); + auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices); + auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get()); + worker_session.reset( + new WorkerSession(session, worker_name, + std::unique_ptr<WorkerCacheInterface>(worker_cache), + std::move(device_mgr), std::move(graph_mgr))); + } else { + // Borrown the WorkerEnv's DeviceMgr for the WorkerSession, so + // that resources using it can use its devices after the + // WorkerSession has been deleted. + auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr); + worker_session = WorkerSession::CreateWithBorrowedDeviceMgr( + session, worker_name, + std::unique_ptr<WorkerCacheInterface>(worker_cache), + worker_env_->device_mgr, std::move(graph_mgr)); + } sessions_.insert(std::make_pair(session, std::move(worker_session))); return Status::OK(); diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 0a10fe240f..04d1d61409 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -65,7 +65,7 @@ class SessionMgr { void ClearLogs(); private: - const WorkerEnv* const worker_env_; // Not owned. + WorkerEnv* const worker_env_; // Not owned. // A note about destruction: // We must delete graph_mgr before device_mgr, due to shared diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index 858e636e08..0da333833a 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -43,15 +43,17 @@ class FakeDevice : public Device { class SessionMgrTest : public ::testing::Test { protected: SessionMgrTest() - : device_(FakeDevice::MakeCPU( - "/job:mnist/replica:0/task:0/device:fakecpu:0")), - mgr_(&env_, "/job:mnist/replica:0/task:0", + : mgr_(&env_, "/job:mnist/replica:0/task:0", std::unique_ptr<WorkerCacheInterface>(), factory_) { - TF_CHECK_OK(mgr_.WorkerSessionForSession("", &legacy_session_)); - env_.local_devices = {device_.get()}; + Device* device = + FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0") + .release(); + env_.local_devices = {device}; + device_mgr_.reset(new DeviceMgr(env_.local_devices)); + env_.device_mgr = device_mgr_.get(); } - std::unique_ptr<Device> device_; + std::unique_ptr<DeviceMgr> device_mgr_; WorkerEnv env_; SessionMgr::WorkerCacheFactory factory_ = [](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { @@ -59,7 +61,6 @@ class SessionMgrTest : public ::testing::Test { return Status::OK(); }; SessionMgr mgr_; - std::shared_ptr<WorkerSession> legacy_session_; }; TEST_F(SessionMgrTest, CreateSessionSimple) { @@ -84,25 +85,25 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) { TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false)); std::shared_ptr<WorkerSession> session_1; TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_1", &session_1)); - std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices(); + std::vector<Device*> devices_1 = session_1->device_mgr()->ListDevices(); EXPECT_EQ(1, devices_1.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false)); std::shared_ptr<WorkerSession> session_2; TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_2", &session_2)); - std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices(); + std::vector<Device*> devices_2 = session_2->device_mgr()->ListDevices(); EXPECT_EQ(1, devices_2.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true)); std::shared_ptr<WorkerSession> session_3; TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_3", &session_3)); - std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices(); + std::vector<Device*> devices_3 = session_3->device_mgr()->ListDevices(); EXPECT_EQ(1, devices_3.size()); TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true)); std::shared_ptr<WorkerSession> session_4; TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_4", &session_4)); - std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices(); + std::vector<Device*> devices_4 = session_4->device_mgr()->ListDevices(); EXPECT_EQ(1, devices_4.size()); EXPECT_EQ(devices_1[0]->resource_manager(), devices_2[0]->resource_manager()); diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 18886babd5..ca6dc1b1de 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -95,9 +95,43 @@ WorkerSession::WorkerSession(const string& session_name, : session_name(session_name), worker_name(worker_name), worker_cache(new WorkerFreeListCache(std::move(worker_cache))), - device_mgr(std::move(device_mgr)), graph_mgr(std::move(graph_mgr)), cluster_flr( - new ClusterFunctionLibraryRuntime(this, !session_name.empty())) {} + new ClusterFunctionLibraryRuntime(this, !session_name.empty())), + device_mgr_(std::move(device_mgr)), + borrowed_device_mgr_(nullptr) {} + +/* static */ +std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr( + const string& session_name, const string& worker_name, + std::unique_ptr<WorkerCacheInterface> worker_cache, + DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr) { + return std::shared_ptr<WorkerSession>( + new WorkerSession(session_name, worker_name, std::move(worker_cache), + borrowed_device_mgr, std::move(graph_mgr))); +} + +WorkerSession::WorkerSession(const string& session_name, + const string& worker_name, + std::unique_ptr<WorkerCacheInterface> worker_cache, + DeviceMgr* borrowed_device_mgr, + std::unique_ptr<GraphMgr> graph_mgr) + : session_name(session_name), + worker_name(worker_name), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + graph_mgr(std::move(graph_mgr)), + cluster_flr( + new ClusterFunctionLibraryRuntime(this, !session_name.empty())), + device_mgr_(nullptr), + borrowed_device_mgr_(borrowed_device_mgr) {} + +WorkerSession::~WorkerSession() { + if (graph_mgr) { + Status s = graph_mgr->DeregisterAll(); + if (!s.ok()) { + LOG(WARNING) << "Error during worker session deletion: " << s; + } + } +} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index 0fd19ac27f..f1faf49364 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -40,10 +40,14 @@ struct WorkerSession { // Object from which WorkerInterface instances can be obtained. const std::unique_ptr<WorkerCacheInterface> worker_cache; - // Collection of local devices. These devices are typically RenamedDevices - // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr - // == worker_env_.device_mgr, which holds the true devices. - const std::unique_ptr<DeviceMgr> device_mgr; + // Collection of local devices. These devices are typically + // RenamedDevices in all except the SessionMgr.legacy_session_ and + // sessions created with `isolate_session_state == false`. In the + // those cases, this method returns a pointer to a borrowed + // DeviceMgr (typically the `worker_env.device_mgr`). + DeviceMgr* device_mgr() { + return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_; + } // graph_mgr keeps track of the registered graphs of this session. // @@ -57,6 +61,22 @@ struct WorkerSession { std::unique_ptr<WorkerCacheInterface> worker_cache, std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr); + + static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr( + const string& session_name, const string& worker_name, + std::unique_ptr<WorkerCacheInterface> worker_cache, + DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr); + + ~WorkerSession(); + + private: + WorkerSession(const string& session_name, const string& worker_name, + std::unique_ptr<WorkerCacheInterface> worker_cache, + DeviceMgr* borrowed_device_mgr, + std::unique_ptr<GraphMgr> graph_mgr); + + const std::unique_ptr<DeviceMgr> device_mgr_; + DeviceMgr* const borrowed_device_mgr_; // Not owned. }; } // namespace tensorflow |