aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-18 15:27:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 15:30:46 -0700
commit80f60ea37ed77b3dbe1d983f101a5efba2fd4f2e (patch)
treec9a5679baed7ff915eb78ac853f02c62224e4b39 /tensorflow/core/distributed_runtime
parente662c3fcfcd03fd091b032a5a33971428f4cdb89 (diff)
Never use the LegacySession when a Master explicitly calls CreateWorkerSession.
Previously, if the session handle was unrecognized by the worker, it would default to using the LegacySession. This prevents us from noticing that a server has been restarted. To address the problem in a backwards-compatible way, we add a bit to each session-handle-carrying worker request, indicating whether the master believes that CreateWorkerSession has been called. If this bit is set and the handle is unrecognized, the worker will raise an AbortedError, which can be caught by high-level frameworks such as `tf.estimator`. Note that CreateWorkerSession is not yet used by default, and a follow-up change will add that. PiperOrigin-RevId: 193427057
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc2
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.h7
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc3
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.cc23
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h10
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc24
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.h8
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr_test.cc34
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc67
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc3
11 files changed, 143 insertions, 40 deletions
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index 000a03da5d..6edc2ec5ed 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -145,6 +145,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
RegisterGraphRequest req;
req.set_session_handle(worker_session_->session_name);
+ req.set_create_worker_session_called(create_worker_session_called_);
*req.mutable_graph_def() = gdef;
req.mutable_graph_options()
->mutable_optimizer_options()
@@ -182,6 +183,7 @@ void ClusterFunctionLibraryRuntime::Run(
RunGraphRequest* req = new RunGraphRequest;
req->set_session_handle(worker_session_->session_name);
+ req->set_create_worker_session_called(create_worker_session_called_);
req->set_graph_handle(function_data->graph_handle);
// Borrowed from master_session.cc
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index d3ca350e36..1ea0a3ad51 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -27,8 +27,10 @@ struct WorkerSession;
// functions across processes by making RPCs.
class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
public:
- ClusterFunctionLibraryRuntime(WorkerSession* worker_session)
- : worker_session_(worker_session) {}
+ ClusterFunctionLibraryRuntime(WorkerSession* worker_session,
+ bool create_worker_session_called)
+ : worker_session_(worker_session),
+ create_worker_session_called_(create_worker_session_called) {}
~ClusterFunctionLibraryRuntime() override;
@@ -51,6 +53,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
mutable mutex mu_;
WorkerSession* const worker_session_ = nullptr; // not owned.
+ const bool create_worker_session_called_;
struct FunctionData {
const string graph_handle;
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
index 1810996ab8..6f96d7cb06 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
@@ -44,7 +44,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<GraphMgr>()));
cluster_flr_.reset(
- new ClusterFunctionLibraryRuntime(worker_session_.get()));
+ new ClusterFunctionLibraryRuntime(worker_session_.get(), true));
}
Status ConstructFunctionGraphHelper(
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index e0a5bb4c53..08020f0266 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -431,6 +431,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
const Part& part = partitions_[i];
Call* c = &calls[i];
c->req.set_session_handle(session_handle_);
+ c->req.set_create_worker_session_called(!should_deregister_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
@@ -587,6 +588,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
c->req->set_is_last_partial_run(is_last_partial_run);
}
c->req->set_session_handle(session_handle_);
+ c->req->set_create_worker_session_called(!should_deregister_);
c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
@@ -1003,6 +1005,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
if (!part.graph_handle.empty()) {
Call* c = new Call;
c->req.set_session_handle(session_handle_);
+ c->req.set_create_worker_session_called(!should_deregister_);
c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called.
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index 18668b44d3..40bf564cab 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -282,10 +282,18 @@ const string& InMemoryRunGraphRequest::session_handle() const {
return session_handle_;
}
+bool InMemoryRunGraphRequest::create_worker_session_called() const {
+ return create_worker_session_called_;
+}
+
void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
session_handle_ = handle;
}
+void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
+ create_worker_session_called_ = called;
+}
+
const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_;
}
@@ -378,6 +386,8 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
proto_version_->set_session_handle(session_handle());
+ proto_version_->set_create_worker_session_called(
+ create_worker_session_called());
proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts();
@@ -403,6 +413,15 @@ void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
request_.set_session_handle(handle);
}
+bool MutableProtoRunGraphRequest::create_worker_session_called() const {
+ return request_.create_worker_session_called();
+}
+
+void MutableProtoRunGraphRequest::set_create_worker_session_called(
+ bool called) {
+ request_.set_create_worker_session_called(called);
+}
+
const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle();
}
@@ -514,6 +533,10 @@ const string& ProtoRunGraphRequest::session_handle() const {
return request_->session_handle();
}
+bool ProtoRunGraphRequest::create_worker_session_called() const {
+ return request_->create_worker_session_called();
+}
+
const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle();
}
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 1f7cdb98a4..92c5668e3a 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -246,6 +246,9 @@ class RunGraphRequestWrapper {
// namespace is used.
virtual const string& session_handle() const = 0;
+ // Set to true if `CreateWorkerSession` was called for `session_handle`.
+ virtual bool create_worker_session_called() const = 0;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
virtual const string& graph_handle() const = 0;
@@ -293,6 +296,7 @@ class RunGraphRequestWrapper {
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public:
virtual void set_session_handle(const string& handle) = 0;
+ virtual void set_create_worker_session_called(bool called) = 0;
virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0;
@@ -317,6 +321,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override;
+ bool create_worker_session_called() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
size_t num_sends() const override;
@@ -331,6 +336,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
+ void set_create_worker_session_called(bool called) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -347,6 +353,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
private:
string session_handle_;
+ bool create_worker_session_called_;
string graph_handle_;
int64 step_id_;
ExecutorOpts exec_opts_;
@@ -370,6 +377,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
+ bool create_worker_session_called() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@@ -385,6 +393,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
+ void set_create_worker_session_called(bool called) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -409,6 +418,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
+ bool create_worker_session_called() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 51b9547f53..e51d63cf2b 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -98,20 +98,26 @@ Status SessionMgr::DeleteSession(const string& session) {
return Status::OK();
}
-std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSessionUnlocked(
- const string& session) {
- auto it = sessions_.find(session);
- if (it == sessions_.end()) {
- return legacy_session_;
+Status SessionMgr::WorkerSessionForSessionLocked(
+ const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
+ if (session_handle.empty()) {
+ *out_session = legacy_session_;
} else {
- return it->second;
+ auto it = sessions_.find(session_handle);
+ if (it == sessions_.end()) {
+ return errors::Aborted("Session handle is not found: ", session_handle,
+ ". Possibly this worker just restarted.");
+ } else {
+ *out_session = it->second;
+ }
}
+ return Status::OK();
}
-std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSession(
- const string& session) {
+Status SessionMgr::WorkerSessionForSession(
+ const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
mutex_lock l(mu_);
- return WorkerSessionForSessionUnlocked(session);
+ return WorkerSessionForSessionLocked(session_handle, out_session);
}
std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index 4c9702d522..0a10fe240f 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -50,7 +50,8 @@ class SessionMgr {
bool isolate_session_state);
// Locates the worker session for a given session handle
- std::shared_ptr<WorkerSession> WorkerSessionForSession(const string& session);
+ Status WorkerSessionForSession(const string& session_handle,
+ std::shared_ptr<WorkerSession>* out_session);
std::shared_ptr<WorkerSession> LegacySession();
Status DeleteSession(const string& session);
@@ -86,8 +87,9 @@ class SessionMgr {
const WorkerCacheFactory worker_cache_factory_;
- std::shared_ptr<WorkerSession> WorkerSessionForSessionUnlocked(
- const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status WorkerSessionForSessionLocked(
+ const string& session_handle, std::shared_ptr<WorkerSession>* out_session)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// A map from session identifier to internal session structure.
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index 4d028f7f4a..858e636e08 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -46,8 +46,8 @@ class SessionMgrTest : public ::testing::Test {
: device_(FakeDevice::MakeCPU(
"/job:mnist/replica:0/task:0/device:fakecpu:0")),
mgr_(&env_, "/job:mnist/replica:0/task:0",
- std::unique_ptr<WorkerCacheInterface>(), factory_),
- legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {
+ std::unique_ptr<WorkerCacheInterface>(), factory_) {
+ TF_CHECK_OK(mgr_.WorkerSessionForSession("", &legacy_session_));
env_.local_devices = {device_.get()};
}
@@ -69,7 +69,8 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
- auto session = mgr_.WorkerSessionForSession(session_handle);
+ std::shared_ptr<WorkerSession> session;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
@@ -81,22 +82,26 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
server_def.set_task_index(3);
TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false));
- auto session_1 = mgr_.WorkerSessionForSession("handle_1");
+ std::shared_ptr<WorkerSession> session_1;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_1", &session_1));
std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices();
EXPECT_EQ(1, devices_1.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false));
- auto session_2 = mgr_.WorkerSessionForSession("handle_2");
+ std::shared_ptr<WorkerSession> session_2;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_2", &session_2));
std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices();
EXPECT_EQ(1, devices_2.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true));
- auto session_3 = mgr_.WorkerSessionForSession("handle_3");
+ std::shared_ptr<WorkerSession> session_3;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_3", &session_3));
std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices();
EXPECT_EQ(1, devices_3.size());
TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true));
- auto session_4 = mgr_.WorkerSessionForSession("handle_4");
+ std::shared_ptr<WorkerSession> session_4;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_4", &session_4));
std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices();
EXPECT_EQ(1, devices_4.size());
@@ -109,12 +114,23 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def;
string session_handle = "";
- auto session = mgr_.WorkerSessionForSession(session_handle);
+ std::shared_ptr<WorkerSession> session;
+ TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
EXPECT_EQ(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
+TEST_F(SessionMgrTest, UnknownSessionHandle) {
+ ServerDef server_def;
+ string session_handle = "unknown_session_handle";
+ std::shared_ptr<WorkerSession> session;
+ Status s = mgr_.WorkerSessionForSession(session_handle, &session);
+ EXPECT_TRUE(errors::IsAborted(s));
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(), "Session handle is not found"));
+}
+
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def;
server_def.set_job_name("worker");
@@ -124,7 +140,7 @@ TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
}
TEST_F(SessionMgrTest, DeleteLegacySession) {
- TF_EXPECT_OK(mgr_.DeleteSession("legacy_session"));
+ TF_EXPECT_OK(mgr_.DeleteSession(""));
}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 598652fb98..6b2536c3c0 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -59,21 +59,37 @@ void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) {
- auto session =
- env_->session_mgr->WorkerSessionForSession(request->session_handle());
- Status s = session->graph_mgr->Register(
- request->session_handle(), request->graph_def(), request->graph_options(),
- request->debug_options(), session->cluster_flr.get(),
- response->mutable_graph_handle());
+ std::shared_ptr<WorkerSession> session;
+ Status s;
+ if (request->create_worker_session_called()) {
+ s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+ &session);
+ } else {
+ session = env_->session_mgr->LegacySession();
+ }
+ if (s.ok()) {
+ s = session->graph_mgr->Register(
+ request->session_handle(), request->graph_def(),
+ request->graph_options(), request->debug_options(),
+ session->cluster_flr.get(), response->mutable_graph_handle());
+ }
done(s);
}
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) {
- auto session =
- env_->session_mgr->WorkerSessionForSession(request->session_handle());
- Status s = session->graph_mgr->Deregister(request->graph_handle());
+ std::shared_ptr<WorkerSession> session;
+ Status s;
+ if (request->create_worker_session_called()) {
+ s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+ &session);
+ } else {
+ session = env_->session_mgr->LegacySession();
+ }
+ if (s.ok()) {
+ s = session->graph_mgr->Deregister(request->graph_handle());
+ }
done(s);
}
@@ -135,11 +151,21 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
StatusCallback done) {
const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
- auto session =
- env_->session_mgr->WorkerSessionForSession(request->session_handle());
+ std::shared_ptr<WorkerSession> session;
+ Status s;
+ if (request->create_worker_session_called()) {
+ s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+ &session);
+ } else {
+ session = env_->session_mgr->LegacySession();
+ }
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
- Status s = PrepareRunGraph(request, &in, out);
+ s = PrepareRunGraph(request, &in, out);
if (!s.ok()) {
delete out;
done(s);
@@ -209,12 +235,23 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const int64 step_id = request->step_id();
const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
- auto session =
- env_->session_mgr->WorkerSessionForSession(request->session_handle());
+ std::shared_ptr<WorkerSession> session;
+
+ Status s;
+ if (request->create_worker_session_called()) {
+ s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+ &session);
+ } else {
+ session = env_->session_mgr->LegacySession();
+ }
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
- Status s = PrepareRunGraph(request, &in, out);
+ s = PrepareRunGraph(request, &in, out);
auto finish = [done, out, opts](const Status& s) {
opts->ClearCancelCallback();
delete out;
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index cb7059b36e..18886babd5 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -97,6 +97,7 @@ WorkerSession::WorkerSession(const string& session_name,
worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
device_mgr(std::move(device_mgr)),
graph_mgr(std::move(graph_mgr)),
- cluster_flr(new ClusterFunctionLibraryRuntime(this)) {}
+ cluster_flr(
+ new ClusterFunctionLibraryRuntime(this, !session_name.empty())) {}
} // namespace tensorflow