aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-11-28 09:42:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 09:46:40 -0800
commit92d65fe6d71b5b80c130f9d9fb4474c4587f2855 (patch)
treee37f0924142f43e532637172e265b87a474797b6
parentb262375fa67d82d84e8cf9304c4c4d63411a0bc3 (diff)
Add `ConfigProto.isolate_session_state` option for the distributed runtime.
Setting this option to true when creating a session ensures that no stateful resources (variables, queues, iterators, etc.) will be visible to any other session running on the same server, and those resources will be deleted when the session is closed. The default behavior, namely that all `tf.Variable` objects are shared by default and most other resources are shared when their `shared_name` attr is non-empty, is preserved. This change augments the semantics of the WorkerService.CreateWorkerSession RPC. Now, if the server_def in the request is empty, it implies that the worker should use its default ClusterSpec. Note that clusters created using ClusterSpec propagation always have isolated session state, and are unaffected by this change. PiperOrigin-RevId: 177173545
-rw-r--r--tensorflow/core/common_runtime/device.h2
-rw-r--r--tensorflow/core/common_runtime/renamed_device.cc11
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h16
-rw-r--r--tensorflow/core/distributed_runtime/BUILD12
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc32
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc23
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.h4
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr_test.cc66
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc3
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_wrapper.h90
-rw-r--r--tensorflow/core/protobuf/config.proto6
-rw-r--r--tensorflow/core/protobuf/worker.proto4
-rw-r--r--tensorflow/python/client/session_clusterspec_prop_test.py43
-rw-r--r--tensorflow/python/training/server_lib_test.py89
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt4
15 files changed, 374 insertions, 31 deletions
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 3912cd177b..d5a452a796 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -131,7 +131,7 @@ class Device : public DeviceBase {
OpSegment* op_segment() { return &op_seg_; }
// Returns the resource manager associated w/ this device.
- ResourceMgr* resource_manager() { return rmgr_; }
+ virtual ResourceMgr* resource_manager() { return rmgr_; }
// Summarizes the status of this Device, for debugging.
string DebugString() const { return ProtoDebugString(device_attributes_); }
diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc
index fa9713735e..56766a8df4 100644
--- a/tensorflow/core/common_runtime/renamed_device.cc
+++ b/tensorflow/core/common_runtime/renamed_device.cc
@@ -21,7 +21,8 @@ namespace tensorflow {
/* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
Device* underlying,
- bool owns_underlying) {
+ bool owns_underlying,
+ bool isolate_session_state) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
@@ -35,15 +36,17 @@ Device* RenamedDevice::NewRenamedDevice(const string& new_base,
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
- return new RenamedDevice(underlying, attributes, owns_underlying);
+ return new RenamedDevice(underlying, attributes, owns_underlying,
+ isolate_session_state);
}
RenamedDevice::RenamedDevice(Device* underlying,
const DeviceAttributes& attributes,
- bool owns_underlying)
+ bool owns_underlying, bool isolate_session_state)
: Device(underlying->env(), attributes),
underlying_(underlying),
- owns_underlying_(owns_underlying) {}
+ owns_underlying_(owns_underlying),
+ isolate_session_state_(isolate_session_state) {}
RenamedDevice::~RenamedDevice() {
if (owns_underlying_) {
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 3103ca0751..c5c204d4fa 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -29,7 +29,9 @@ namespace tensorflow {
class RenamedDevice : public Device {
public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
- bool owns_underlying);
+ bool owns_underlying,
+ bool isolate_session_state);
+
~RenamedDevice() override;
// Below are virtual methods defined on DeviceBase
@@ -113,11 +115,21 @@ class RenamedDevice : public Device {
return underlying_->FillContextMap(graph, device_context_map);
}
+ // Returns the resource manager associated w/ this device.
+ ResourceMgr* resource_manager() override {
+ if (isolate_session_state_) {
+ return Device::resource_manager();
+ } else {
+ return underlying_->resource_manager();
+ }
+ }
+
private:
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
- bool owns_underlying);
+ bool owns_underlying, bool isolate_session_state);
Device* const underlying_;
const bool owns_underlying_;
+ const bool isolate_session_state_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 93adc7ef4f..29164bbffe 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -140,6 +140,7 @@ cc_library(
hdrs = ["session_mgr.h"],
deps = [
":graph_mgr",
+ ":worker_cache_wrapper",
":worker_session",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
@@ -264,6 +265,17 @@ cc_library(
)
cc_library(
+ name = "worker_cache_wrapper",
+ hdrs = ["worker_cache_wrapper.h"],
+ deps = [
+ ":worker_cache",
+ ":worker_interface",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "remote_device",
srcs = ["remote_device.cc"],
hdrs = ["remote_device.h"],
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 3379302b9b..03b65d8cba 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1049,7 +1049,10 @@ Status MasterSession::Create(GraphDef* graph_def,
TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
graph_def, execution_options, &execution_state_));
}
- if (options.cluster_def != nullptr) {
+ // TODO(b/36574172): Remove these conditions when ClusterSpec
+ // propagation is supported in all servers.
+ if (options.cluster_def != nullptr ||
+ session_opts_.config.isolate_session_state()) {
should_delete_worker_sessions_ = true;
return CreateWorkerSessions(options);
}
@@ -1058,10 +1061,9 @@ Status MasterSession::Create(GraphDef* graph_def,
Status MasterSession::CreateWorkerSessions(
const WorkerCacheFactoryOptions& options) {
- CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
- << "dynamic cluster membership.";
std::vector<string> worker_names;
- worker_cache_->ListWorkers(&worker_names);
+ WorkerCacheInterface* worker_cache = get_worker_cache();
+ worker_cache->ListWorkers(&worker_names);
struct WorkerGroup {
// The worker name. (Not owned.)
@@ -1079,10 +1081,10 @@ Status MasterSession::CreateWorkerSessions(
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
- auto cleanup = gtl::MakeCleanup([this, &workers] {
+ auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
- worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
+ worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
}
}
});
@@ -1091,11 +1093,19 @@ Status MasterSession::CreateWorkerSessions(
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
- workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
+ workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
- *workers[i].request.mutable_server_def()->mutable_cluster() =
- *options.cluster_def;
- workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+ if (options.cluster_def) {
+ *workers[i].request.mutable_server_def()->mutable_cluster() =
+ *options.cluster_def;
+ workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+ // Session state is always isolated when ClusterSpec propagation
+ // is in use.
+ workers[i].request.set_isolate_session_state(true);
+ } else {
+ workers[i].request.set_isolate_session_state(
+ session_opts_.config.isolate_session_state());
+ }
DeviceNameUtils::ParsedName name;
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
@@ -1162,7 +1172,7 @@ Status MasterSession::DeleteWorkerSessions() {
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
- workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
+ workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
}
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index b97749dc41..fabcbd00f5 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -20,7 +20,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
@@ -29,7 +32,10 @@ SessionMgr::SessionMgr(
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env),
- legacy_session_("", default_worker_name, std::move(default_worker_cache),
+ default_worker_cache_(std::move(default_worker_cache)),
+ legacy_session_("", default_worker_name,
+ std::unique_ptr<WorkerCacheInterface>(
+ new WorkerCacheWrapper(default_worker_cache_.get())),
std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
std::unique_ptr<GraphMgr>(
new GraphMgr(worker_env, worker_env->device_mgr))),
@@ -41,7 +47,8 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
}
Status SessionMgr::CreateSession(const string& session,
- const ServerDef& server_def) {
+ const ServerDef& server_def,
+ bool isolate_session_state) {
mutex_lock l(mu_);
if (session.empty()) {
return errors::InvalidArgument("Session must be non-empty.");
@@ -50,12 +57,18 @@ Status SessionMgr::CreateSession(const string& session,
const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr;
- TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
+ if (server_def.cluster().job().empty()) {
+ worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
+ } else {
+ TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
+ }
+ CHECK(!worker_env_->local_devices.empty())
+ << "The WorkerEnv must have at least one device in `local_devices`.";
std::vector<Device*> renamed_devices;
for (Device* d : worker_env_->local_devices) {
- renamed_devices.push_back(
- RenamedDevice::NewRenamedDevice(worker_name, d, false));
+ renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
+ worker_name, d, false, isolate_session_state));
}
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index c44bca7b7a..d85b6c3059 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -45,7 +45,8 @@ class SessionMgr {
~SessionMgr() {}
// Allocates state for a new session.
- Status CreateSession(const string& session, const ServerDef& server_def);
+ Status CreateSession(const string& session, const ServerDef& server_def,
+ bool isolate_session_state);
// Locates the worker session for a given session handle
WorkerSession* WorkerSessionForSession(const string& session);
@@ -71,6 +72,7 @@ class SessionMgr {
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
// device_mgr is deleted after WorkerSession's graph_mgr.
+ std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_;
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index 7132f123a5..ffe4809f2b 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -22,14 +22,36 @@ limitations under the License.
namespace tensorflow {
+class FakeDevice : public Device {
+ private:
+ explicit FakeDevice(const DeviceAttributes& device_attributes)
+ : Device(nullptr, device_attributes) {}
+
+ public:
+ Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
+
+ static std::unique_ptr<Device> MakeCPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType("FakeCPU").type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+};
+
class SessionMgrTest : public ::testing::Test {
protected:
SessionMgrTest()
- : mgr_(&env_, "/job:mnist/replica:0/task:0",
- std::unique_ptr<WorkerCacheInterface>(),
- factory_),
- legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
+ : device_(FakeDevice::MakeCPU(
+ "/job:mnist/replica:0/task:0/device:fakecpu:0")),
+ mgr_(&env_, "/job:mnist/replica:0/task:0",
+ std::unique_ptr<WorkerCacheInterface>(), factory_),
+ legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {
+ env_.local_devices = {device_.get()};
+ }
+ std::unique_ptr<Device> device_;
WorkerEnv env_;
SessionMgr::WorkerCacheFactory factory_ =
[](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
@@ -42,14 +64,48 @@ class SessionMgrTest : public ::testing::Test {
TEST_F(SessionMgrTest, CreateSessionSimple) {
ServerDef server_def;
+ server_def.set_job_name("worker");
+ server_def.set_task_index(3);
+
string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
+ TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
+TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
+ ServerDef server_def;
+ server_def.set_job_name("worker");
+ server_def.set_task_index(3);
+
+ TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false));
+ WorkerSession* session_1 = mgr_.WorkerSessionForSession("handle_1");
+ std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices();
+ EXPECT_EQ(1, devices_1.size());
+
+ TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false));
+ WorkerSession* session_2 = mgr_.WorkerSessionForSession("handle_2");
+ std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices();
+ EXPECT_EQ(1, devices_2.size());
+
+ TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true));
+ WorkerSession* session_3 = mgr_.WorkerSessionForSession("handle_3");
+ std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices();
+ EXPECT_EQ(1, devices_3.size());
+
+ TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true));
+ WorkerSession* session_4 = mgr_.WorkerSessionForSession("handle_4");
+ std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices();
+ EXPECT_EQ(1, devices_4.size());
+
+ EXPECT_EQ(devices_1[0]->resource_manager(), devices_2[0]->resource_manager());
+ EXPECT_NE(devices_1[0]->resource_manager(), devices_3[0]->resource_manager());
+ EXPECT_NE(devices_1[0]->resource_manager(), devices_4[0]->resource_manager());
+ EXPECT_NE(devices_3[0]->resource_manager(), devices_4[0]->resource_manager());
+}
+
TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def;
string session_handle = "";
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 8bf87923ed..6cd92f5fe7 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -44,7 +44,8 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response,
StatusCallback done) {
Status s = env_->session_mgr->CreateSession(request->session_handle(),
- request->server_def());
+ request->server_def(),
+ request->isolate_session_state());
done(s);
}
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
new file mode 100644
index 0000000000..43c3b6285b
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+
+namespace tensorflow {
+
+class WorkerCacheWrapper : public WorkerCacheInterface {
+ public:
+ WorkerCacheWrapper(WorkerCacheInterface* wrapped) : wrapped_(wrapped) {}
+
+ // Updates *workers with strings naming the remote worker tasks to
+ // which open channels have been established.
+ virtual void ListWorkers(std::vector<string>* workers) const {
+ return wrapped_->ListWorkers(workers);
+ }
+
+ // If "target" names a remote task for which an RPC channel exists
+ // or can be constructed, returns a pointer to a WorkerInterface object
+ // wrapping that channel. The returned value must be destroyed by
+ // calling `this->ReleaseWorker(target, ret)`
+ // TODO(mrry): rename this to GetOrCreateWorker() or something that
+ // makes it more obvious that this method returns a potentially
+ // shared object.
+ virtual WorkerInterface* CreateWorker(const string& target) {
+ return wrapped_->CreateWorker(target);
+ }
+
+ // Release a worker previously returned by this->CreateWorker(target).
+ //
+ // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
+ // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
+ // per-rpc-subsystem WorkerInterface creator.
+ virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
+ return wrapped_->ReleaseWorker(target, worker);
+ }
+
+ // Set *locality with the DeviceLocality of the specified remote device
+ // within its local environment. Returns true if *locality
+ // was set, using only locally cached data. Returns false
+ // if status data for that device was not available. Never blocks.
+ virtual bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) {
+ return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
+ }
+
+ // Set *locality with the DeviceLocality of the specified remote device
+ // within its local environment. Callback gets Status::OK if *locality
+ // was set.
+ virtual void GetDeviceLocalityAsync(const string& device,
+ DeviceLocality* locality,
+ StatusCallback done) {
+ return wrapped_->GetDeviceLocalityAsync(device, locality, std::move(done));
+ }
+
+ // Start/stop logging activity.
+ virtual void SetLogging(bool active) { wrapped_->SetLogging(active); }
+
+ // Discard any saved log data.
+ virtual void ClearLogs() { wrapped_->ClearLogs(); }
+
+ // Return logs for the identified step in *ss. Any returned data will no
+ // longer be stored.
+ virtual bool RetrieveLogs(int64 step_id, StepStats* ss) {
+ return wrapped_->RetrieveLogs(step_id, ss);
+ }
+
+ private:
+ WorkerCacheInterface* wrapped_; // Not owned.
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index a956aab3dc..1916316245 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -303,7 +303,11 @@ message ConfigProto {
// Optional list of all workers to use in this session.
ClusterDef cluster_def = 14;
- // Next: 15
+ // If true, any resources such as Variables used in the session will not be
+ // shared with other sessions.
+ bool isolate_session_state = 15;
+
+ // Next: 16
};
// Options for a single Run() call.
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index e7b3f36fcc..385e2dd163 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -59,6 +59,10 @@ message CreateWorkerSessionRequest {
// Defines the configuration of a TensorFlow worker.
ServerDef server_def = 2;
+
+ // If true, any resources such as Variables used in the session will not be
+ // shared with other sessions.
+ bool isolate_session_state = 3;
}
message CreateWorkerSessionResponse {
diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index 28a4dd27a7..c85b22eb15 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -415,6 +416,48 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
node_stats.node_name.startswith('Const')
]), run_metadata)
+ def testClusterSpecPropagationIsolation(self):
+ """Test that two sessions using ClusterSpec propagation are isolated."""
+ server = server_lib.Server.create_local_server()
+ init_value = array_ops.placeholder(dtypes.int32, shape=[])
+ v = variables.Variable(init_value)
+
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ sess1 = session.Session(server.target, config=config)
+ sess2 = session.Session(server.target, config=config)
+
+ # Initially, the variable is uninitialized in both sessions.
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess1.run(v)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess2.run(v)
+
+ # An update in sess1 should be visible in sess1 only.
+ sess1.run(v.initializer, feed_dict={init_value: 37})
+ self.assertEqual(37, sess1.run(v))
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess2.run(v)
+
+ # An update in sess2 should be visible in sess2 only.
+ sess2.run(v.initializer, feed_dict={init_value: 86})
+ self.assertEqual(37, sess1.run(v))
+ self.assertEqual(86, sess2.run(v))
+
+ # Closing sess2 has no effect on the state of sess1.
+ sess2.close()
+ self.assertEqual(37, sess1.run(v))
+
+ # Subsequent sessions will not see the state of existing sessions.
+ sess3 = session.Session(server.target, config=config)
+ self.assertEqual(37, sess1.run(v))
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess3.run(v)
+
@test_util.disable_c_api # Partial runs don't work with C API
def testClusterSpecPropagationPartialRun(self):
"""Test successful partial run with ClusterSpec propagation."""
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 0a8ec4901c..26aac787ed 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -241,6 +241,95 @@ class GrpcServerTest(test.TestCase):
queue_runner_impl.start_queue_runners(sess)
sess.run(var.assign(3.0))
+ def testIsolateSessionState(self):
+ server = self._cached_server
+
+ init_value = array_ops.placeholder(dtypes.int32)
+ v = variables.Variable(init_value, validate_shape=False, name="v")
+
+ sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
+ sharing_sess_0 = session.Session(server.target, config=sharing_config)
+ sharing_sess_1 = session.Session(server.target, config=sharing_config)
+
+ isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
+ isolate_sess_0 = session.Session(server.target, config=isolate_config)
+ isolate_sess_1 = session.Session(server.target, config=isolate_config)
+
+ # Initially all variables are initialized.
+ for sess in [sharing_sess_0, sharing_sess_1,
+ isolate_sess_0, isolate_sess_1]:
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ sess.run(v)
+
+ # Shared sessions will see each other's updates, but isolated sessions
+ # will not.
+ sharing_sess_0.run(v.initializer, feed_dict={init_value: 86})
+ self.assertAllEqual(86, sharing_sess_0.run(v))
+ self.assertAllEqual(86, sharing_sess_1.run(v))
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ isolate_sess_0.run(v)
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ isolate_sess_1.run(v)
+
+ # Changing the shape works because `validate_shape` is False.
+ sharing_sess_1.run(v.initializer, feed_dict={init_value: [86, 99]})
+ self.assertAllEqual([86, 99], sharing_sess_0.run(v))
+ self.assertAllEqual([86, 99], sharing_sess_1.run(v))
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ isolate_sess_0.run(v)
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ isolate_sess_1.run(v)
+
+ # Initializing in an isolated session will only affect the state in that
+ # session.
+ isolate_sess_0.run(v.initializer, feed_dict={init_value: 37})
+ self.assertAllEqual([86, 99], sharing_sess_0.run(v))
+ self.assertAllEqual([86, 99], sharing_sess_1.run(v))
+ self.assertAllEqual(37, isolate_sess_0.run(v))
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ isolate_sess_1.run(v)
+
+ # Isolated sessions can have different shapes for the same variable.
+ isolate_sess_1.run(v.initializer, feed_dict={init_value: [19, 86]})
+ self.assertAllEqual([86, 99], sharing_sess_0.run(v))
+ self.assertAllEqual([86, 99], sharing_sess_1.run(v))
+ self.assertAllEqual(37, isolate_sess_0.run(v))
+ self.assertAllEqual([19, 86], isolate_sess_1.run(v))
+
+ def testShapeChangingIsolateState(self):
+ server = self._cached_server
+ sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
+ isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
+
+ with ops.Graph().as_default():
+ w_vector = variables.Variable([1, 2, 3], name="w")
+ with session.Session(server.target, config=sharing_config) as sess:
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ sess.run(w_vector)
+ sess.run(w_vector.initializer)
+ self.assertAllEqual([1, 2, 3], sess.run(w_vector))
+
+ with ops.Graph().as_default():
+ w_vector = variables.Variable([4, 5, 6], name="w")
+ with session.Session(server.target, config=sharing_config) as sess:
+ self.assertAllEqual([1, 2, 3], sess.run(w_vector))
+ sess.run(w_vector.initializer)
+ self.assertAllEqual([4, 5, 6], sess.run(w_vector))
+
+ with ops.Graph().as_default():
+ w_scalar = variables.Variable(86, name="w")
+ with session.Session(server.target, config=sharing_config) as sess:
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ sess.run(w_scalar.initializer)
+
+ with ops.Graph().as_default():
+ w_scalar = variables.Variable(37, name="w")
+ with session.Session(server.target, config=isolate_config) as sess:
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ sess.run(w_scalar)
+ sess.run(w_scalar.initializer)
+ self.assertAllEqual(37, sess.run(w_scalar))
+
class ServerDefTest(test.TestCase):
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
index da6af3919e..009d64aed0 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
@@ -47,6 +47,10 @@ tf_class {
mtype: "<type \'int\'>"
}
member {
+ name: "ISOLATE_SESSION_STATE_FIELD_NUMBER"
+ mtype: "<type \'int\'>"
+ }
+ member {
name: "LOG_DEVICE_PLACEMENT_FIELD_NUMBER"
mtype: "<type \'int\'>"
}