aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-19 10:57:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 11:00:55 -0700
commitf196351cd4e21ed6c17dcf544e0fa6cfa3030b4e (patch)
tree740b087e8580ba8332987a994a88b7052abb2c4e /tensorflow/core/distributed_runtime
parent6a7779f3384e48012d3e27ae0f48d410f5174d06 (diff)
Allow non-isolated worker sessions to borrow `WorkerEnv::device_mgr`.
Without this change, a shared resource (e.g. an Iterator) could not be created in one session `s1`, and used in a later session `s2` after `s1` was closed, because the iterator might indirectly capture devices from the previous session, and use them after they are freed when the `WorkerSession` was deleted. The current change only affects the singleton "legacy" WorkerSession, which is never deleted, but this is necessary to switch all sessions to use separate WorkerSession objects. PiperOrigin-RevId: 193541426
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc2
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc40
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.h2
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr_test.cc23
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc38
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.h28
8 files changed, 104 insertions, 34 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b07cb8cdcb..d564727da5 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -133,6 +133,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:ptr_util",
"//tensorflow/core:worker_proto_cc",
],
)
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index bafd9bfc68..5f6931e008 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -253,13 +253,13 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
WorkerSession* sess = session();
Device* src_device;
- Status s = sess->device_mgr->LookupDevice(parsed.src_device, &src_device);
+ Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
if (!s.ok()) {
done(s);
return;
}
Device* dst_device;
- s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
if (!s.ok()) {
done(s);
return;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 067dc5dff5..b8cb538503 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -227,7 +227,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
Device* dst_device;
if (s.ok()) {
- s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
if (rwi != nullptr) {
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index e51d63cf2b..357e9f8930 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -33,11 +34,11 @@ SessionMgr::SessionMgr(
WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env),
default_worker_cache_(std::move(default_worker_cache)),
- legacy_session_(new WorkerSession(
+ legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
"", default_worker_name,
std::unique_ptr<WorkerCacheInterface>(
new WorkerCacheWrapper(default_worker_cache_.get())),
- std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
+ worker_env->device_mgr,
std::unique_ptr<GraphMgr>(
new GraphMgr(worker_env, worker_env->device_mgr)))),
worker_cache_factory_(std::move(worker_cache_factory)) {}
@@ -71,19 +72,32 @@ Status SessionMgr::CreateSession(const string& session,
CHECK(!worker_env_->local_devices.empty())
<< "The WorkerEnv must have at least one device in `local_devices`.";
- std::vector<Device*> renamed_devices;
- for (Device* d : worker_env_->local_devices) {
- renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
- worker_name, d, false, isolate_session_state));
- }
- std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
+ std::shared_ptr<WorkerSession> worker_session;
- std::unique_ptr<GraphMgr> graph_mgr(
- new GraphMgr(worker_env_, device_mgr.get()));
+ if (isolate_session_state) {
+ // Create a private copy of the DeviceMgr for the WorkerSession.
+ std::vector<Device*> renamed_devices;
+ for (Device* d : worker_env_->local_devices) {
+ renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
+ worker_name, d, false, isolate_session_state));
+ }
- std::shared_ptr<WorkerSession> worker_session(new WorkerSession(
- session, worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
- std::move(device_mgr), std::move(graph_mgr)));
+ auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices);
+ auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
+ worker_session.reset(
+ new WorkerSession(session, worker_name,
+ std::unique_ptr<WorkerCacheInterface>(worker_cache),
+ std::move(device_mgr), std::move(graph_mgr)));
+ } else {
+ // Borrown the WorkerEnv's DeviceMgr for the WorkerSession, so
+ // that resources using it can use its devices after the
+ // WorkerSession has been deleted.
+ auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
+ worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
+ session, worker_name,
+ std::unique_ptr<WorkerCacheInterface>(worker_cache),
+ worker_env_->device_mgr, std::move(graph_mgr));
+ }
sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK();
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index 0a10fe240f..04d1d61409 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -65,7 +65,7 @@ class SessionMgr {
void ClearLogs();
private:
- const WorkerEnv* const worker_env_; // Not owned.
+ WorkerEnv* const worker_env_; // Not owned.
// A note about destruction:
// We must delete graph_mgr before device_mgr, due to shared
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index 858e636e08..0da333833a 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -43,15 +43,17 @@ class FakeDevice : public Device {
class SessionMgrTest : public ::testing::Test {
protected:
SessionMgrTest()
- : device_(FakeDevice::MakeCPU(
- "/job:mnist/replica:0/task:0/device:fakecpu:0")),
- mgr_(&env_, "/job:mnist/replica:0/task:0",
+ : mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) {
- TF_CHECK_OK(mgr_.WorkerSessionForSession("", &legacy_session_));
- env_.local_devices = {device_.get()};
+ Device* device =
+ FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0")
+ .release();
+ env_.local_devices = {device};
+ device_mgr_.reset(new DeviceMgr(env_.local_devices));
+ env_.device_mgr = device_mgr_.get();
}
- std::unique_ptr<Device> device_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
WorkerEnv env_;
SessionMgr::WorkerCacheFactory factory_ =
[](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
@@ -59,7 +61,6 @@ class SessionMgrTest : public ::testing::Test {
return Status::OK();
};
SessionMgr mgr_;
- std::shared_ptr<WorkerSession> legacy_session_;
};
TEST_F(SessionMgrTest, CreateSessionSimple) {
@@ -84,25 +85,25 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false));
std::shared_ptr<WorkerSession> session_1;
TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_1", &session_1));
- std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices();
+ std::vector<Device*> devices_1 = session_1->device_mgr()->ListDevices();
EXPECT_EQ(1, devices_1.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false));
std::shared_ptr<WorkerSession> session_2;
TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_2", &session_2));
- std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices();
+ std::vector<Device*> devices_2 = session_2->device_mgr()->ListDevices();
EXPECT_EQ(1, devices_2.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true));
std::shared_ptr<WorkerSession> session_3;
TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_3", &session_3));
- std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices();
+ std::vector<Device*> devices_3 = session_3->device_mgr()->ListDevices();
EXPECT_EQ(1, devices_3.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true));
std::shared_ptr<WorkerSession> session_4;
TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_4", &session_4));
- std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices();
+ std::vector<Device*> devices_4 = session_4->device_mgr()->ListDevices();
EXPECT_EQ(1, devices_4.size());
EXPECT_EQ(devices_1[0]->resource_manager(), devices_2[0]->resource_manager());
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 18886babd5..ca6dc1b1de 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -95,9 +95,43 @@ WorkerSession::WorkerSession(const string& session_name,
: session_name(session_name),
worker_name(worker_name),
worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
- device_mgr(std::move(device_mgr)),
graph_mgr(std::move(graph_mgr)),
cluster_flr(
- new ClusterFunctionLibraryRuntime(this, !session_name.empty())) {}
+ new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
+ device_mgr_(std::move(device_mgr)),
+ borrowed_device_mgr_(nullptr) {}
+
+/* static */
+std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
+ const string& session_name, const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr) {
+ return std::shared_ptr<WorkerSession>(
+ new WorkerSession(session_name, worker_name, std::move(worker_cache),
+ borrowed_device_mgr, std::move(graph_mgr)));
+}
+
+WorkerSession::WorkerSession(const string& session_name,
+ const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ DeviceMgr* borrowed_device_mgr,
+ std::unique_ptr<GraphMgr> graph_mgr)
+ : session_name(session_name),
+ worker_name(worker_name),
+ worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
+ graph_mgr(std::move(graph_mgr)),
+ cluster_flr(
+ new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
+ device_mgr_(nullptr),
+ borrowed_device_mgr_(borrowed_device_mgr) {}
+
+WorkerSession::~WorkerSession() {
+ if (graph_mgr) {
+ Status s = graph_mgr->DeregisterAll();
+ if (!s.ok()) {
+ LOG(WARNING) << "Error during worker session deletion: " << s;
+ }
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h
index 0fd19ac27f..f1faf49364 100644
--- a/tensorflow/core/distributed_runtime/worker_session.h
+++ b/tensorflow/core/distributed_runtime/worker_session.h
@@ -40,10 +40,14 @@ struct WorkerSession {
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache;
- // Collection of local devices. These devices are typically RenamedDevices
- // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr
- // == worker_env_.device_mgr, which holds the true devices.
- const std::unique_ptr<DeviceMgr> device_mgr;
+ // Collection of local devices. These devices are typically
+ // RenamedDevices in all except the SessionMgr.legacy_session_ and
+ // sessions created with `isolate_session_state == false`. In the
+ // those cases, this method returns a pointer to a borrowed
+ // DeviceMgr (typically the `worker_env.device_mgr`).
+ DeviceMgr* device_mgr() {
+ return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_;
+ }
// graph_mgr keeps track of the registered graphs of this session.
//
@@ -57,6 +61,22 @@ struct WorkerSession {
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<GraphMgr> graph_mgr);
+
+ static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr(
+ const string& session_name, const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr);
+
+ ~WorkerSession();
+
+ private:
+ WorkerSession(const string& session_name, const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ DeviceMgr* borrowed_device_mgr,
+ std::unique_ptr<GraphMgr> graph_mgr);
+
+ const std::unique_ptr<DeviceMgr> device_mgr_;
+ DeviceMgr* const borrowed_device_mgr_; // Not owned.
};
} // namespace tensorflow