author	Brennan Saeta <saeta@google.com>	2017-05-04 19:43:48 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-05-04 21:08:47 -0700
commit	f28935a7d280b6ba75fe93fe35783d87b9cc2ec9 (patch)
tree	db03a72f0dd29e8e09dcf7805684268b3d40a54c /tensorflow
parent	efa08d80a53a95ce6b8beb61ec86a275aed6b6c7 (diff)
Implement ClusterSpec Propagation in TF Master
ClusterSpec propagation is a capability upgrade for TensorFlow that should make it much easier to (1) build distributed TensorFlow clusters and (2) handle node failures. With ClusterSpec propagation, TensorFlow workers can be booted independently of one another, with no knowledge of each other. The client then constructs a ClusterDef (ClusterSpec) and sends it to the TF master at session-creation time; the master in turn propagates the ClusterDef to all of the workers.

Change: 155159972
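For context, a minimal client-side sketch of the flow this change enables, using hypothetical worker addresses (tensorflow/python/client/session_test.py in this change exercises the real API):

    import tensorflow as tf

    # Workers are booted independently; the client assembles the cluster view.
    cluster = tf.train.ClusterSpec({
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
        "ps": ["ps0.example.com:2222"],
    })

    # The ClusterDef rides along in the ConfigProto at session creation;
    # the master then propagates it to every worker.
    config = tf.ConfigProto(cluster_def=cluster.as_cluster_def())
    with tf.Session("grpc://worker0.example.com:2222", config=config) as sess:
        print(sess.run(tf.constant(42)))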
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compilation_device.cc | 3
-rw-r--r--  tensorflow/contrib/cmake/tf_core_framework.cmake | 1
-rw-r--r--  tensorflow/contrib/makefile/proto_text_pb_cc_files.txt | 1
-rw-r--r--  tensorflow/contrib/makefile/proto_text_pb_h_files.txt | 1
-rw-r--r--  tensorflow/contrib/makefile/tf_pb_text_files.txt | 1
-rw-r--r--  tensorflow/contrib/makefile/tf_proto_files.txt | 1
-rw-r--r--  tensorflow/core/BUILD | 1
-rw-r--r--  tensorflow/core/common_runtime/device.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/device.h | 3
-rw-r--r--  tensorflow/core/common_runtime/device_mgr.cc | 19
-rw-r--r--  tensorflow/core/common_runtime/device_mgr.h | 2
-rw-r--r--  tensorflow/core/common_runtime/device_set.h | 5
-rw-r--r--  tensorflow/core/common_runtime/device_set_test.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 7
-rw-r--r--  tensorflow/core/common_runtime/local_device.cc | 6
-rw-r--r--  tensorflow/core/common_runtime/local_device.h | 4
-rw-r--r--  tensorflow/core/common_runtime/renamed_device.cc | 54
-rw-r--r--  tensorflow/core/common_runtime/renamed_device.h | 119
-rw-r--r--  tensorflow/core/common_runtime/simple_placer_test.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device.cc | 6
-rw-r--r--  tensorflow/core/distributed_runtime/BUILD | 4
-rw-r--r--  tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc | 102
-rw-r--r--  tensorflow/core/distributed_runtime/base_rendezvous_mgr.h | 43
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc | 37
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.h | 11
-rw-r--r--  tensorflow/core/distributed_runtime/master.cc | 121
-rw-r--r--  tensorflow/core/distributed_runtime/master_env.h | 34
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.cc | 151
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.h | 21
-rw-r--r--  tensorflow/core/distributed_runtime/message_wrappers.cc | 21
-rw-r--r--  tensorflow/core/distributed_runtime/message_wrappers.h | 11
-rw-r--r--  tensorflow/core/distributed_runtime/remote_device.cc | 49
-rw-r--r--  tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h | 22
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 83
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 15
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session.cc | 7
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc | 14
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc | 100
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h | 13
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc | 19
-rw-r--r--  tensorflow/core/distributed_runtime/session_mgr.cc | 108
-rw-r--r--  tensorflow/core/distributed_runtime/session_mgr.h | 44
-rw-r--r--  tensorflow/core/distributed_runtime/session_mgr_test.cc | 81
-rw-r--r--  tensorflow/core/distributed_runtime/worker.cc | 32
-rw-r--r--  tensorflow/core/distributed_runtime/worker_env.h | 11
-rw-r--r--  tensorflow/core/distributed_runtime/worker_interface.h | 5
-rw-r--r--  tensorflow/core/distributed_runtime/worker_session.cc | 84
-rw-r--r--  tensorflow/core/distributed_runtime/worker_session.h | 12
-rw-r--r--  tensorflow/core/framework/device_base.h | 8
-rw-r--r--  tensorflow/core/protobuf/cluster.proto | 82
-rw-r--r--  tensorflow/core/protobuf/config.proto | 6
-rw-r--r--  tensorflow/core/protobuf/master.proto | 3
-rw-r--r--  tensorflow/core/protobuf/tensorflow_server.proto | 64
-rw-r--r--  tensorflow/core/protobuf/worker.proto | 12
-rw-r--r--  tensorflow/python/__init__.py | 2
-rw-r--r--  tensorflow/python/client/session_test.py | 267
-rw-r--r--  tensorflow/python/training/server_lib.py | 7
-rw-r--r--  tensorflow/python/training/training.py | 30
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt | 2
63 files changed, 1396 insertions, 594 deletions
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 93f487c36c..5e336c5287 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceType& jit_device_name,
perftools::gputools::Platform* platform,
Allocator* xla_allocator)
- : LocalDevice(options, attrs, xla_allocator),
+ : LocalDevice(options, attrs),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator),
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index d86e741b69..362a101895 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
options,
Device::BuildDeviceAttributes(
"", type, Bytes(256 << 20), DeviceLocality(),
- strings::StrCat("device: XLA compilation device ", type.type())),
- cpu_allocator()),
+ strings::StrCat("device: XLA compilation device ", type.type()))),
allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {}
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 6fd1ae0814..560e45fc13 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -118,6 +118,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/types.proto"
"tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto"
+ "tensorflow/core/protobuf/cluster.proto"
"tensorflow/core/protobuf/config.proto"
"tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/rewriter_config.proto"
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index c0969e6dee..2f1fcb149e 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 132b477596..6087a45168 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index f1da05e4c6..c39257ffa9 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,6 +1,7 @@
tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 2a78ea6101..5eadf5d55b 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/rewriter_config.proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 435618ace7..9d0c6a6c3e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto",
"lib/core/error_codes.proto",
"protobuf/config.proto",
+ "protobuf/cluster.proto",
"protobuf/debug.proto",
"protobuf/queue_runner.proto",
"protobuf/rewriter_config.proto",
diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc
index 78649afeb9..aa8a2d989b 100644
--- a/tensorflow/core/common_runtime/device.cc
+++ b/tensorflow/core/common_runtime/device.cc
@@ -23,8 +23,7 @@ limitations under the License.
namespace tensorflow {
-Device::Device(Env* env, const DeviceAttributes& device_attributes,
- Allocator* device_allocator)
+Device::Device(Env* env, const DeviceAttributes& device_attributes)
: DeviceBase(env), device_attributes_(device_attributes) {
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
<< "Invalid device name: " << name();
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 07c6bdd683..c0e58f143e 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -53,8 +53,7 @@ namespace tensorflow {
class Device : public DeviceBase {
public:
- Device(Env* env, const DeviceAttributes& device_attributes,
- Allocator* device_allocator);
+ Device(Env* env, const DeviceAttributes& device_attributes);
~Device() override;
// Full name of this device (see top comment).
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index 7807656cb2..31f12d4833 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
for (Device* d : devices) {
devices_.push_back(d);
- // Register under both the full name and the local name.
+ // Register under the (1) full name, (2) canonical name, and (3) local name.
string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d;
+ DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
+ if (parsed_name.has_job && parsed_name.has_replica &&
+ parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
+ string canonical_name = DeviceNameUtils::FullName(
+ parsed_name.job, parsed_name.replica, parsed_name.task,
+ parsed_name.type, parsed_name.id);
+ device_map_[CopyToBackingStore(canonical_name)] = d;
+ }
string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++;
@@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
}
DeviceMgr::~DeviceMgr() {
- for (auto p : devices_) delete p;
+ // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
+ for (Device* p : devices_) delete p;
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
@@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s;
auto iter = device_map_.find(name);
if (iter == device_map_.end()) {
+ std::vector<StringPiece> device_names;
+ for (auto&& itr : device_map_) {
+ device_names.push_back(itr.first);
+ }
+ LOG(WARNING) << "Unknown device: " << name
+ << " all devices: " << str_util::Join(device_names, ", ");
return errors::InvalidArgument(name, " unknown device.");
}
*device = iter->second;
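As a reference for the three name forms registered above, here is a sketch using the Python DeviceSpec helper; the concrete strings are illustrative, not taken from this change:

    from tensorflow.python.framework.device import DeviceSpec

    full = "/job:worker/replica:0/task:1/gpu:0"      # d->name(), as reported
    spec = DeviceSpec.from_string(full)

    # Canonical form, analogous to DeviceNameUtils::FullName(job, replica,
    # task, type, id) rebuilt from the parsed components:
    canonical = DeviceSpec(job=spec.job, replica=spec.replica, task=spec.task,
                           device_type=spec.device_type,
                           device_index=spec.device_index).to_string()
    # e.g. "/job:worker/replica:0/task:1/device:GPU:0"

    # Local form, analogous to DeviceNameUtils::LocalName(d->name()):
    local = "/%s:%d" % (spec.device_type.lower(), spec.device_index)
    # e.g. "/gpu:0" (the exact legacy spelling may differ)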
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index bb1ed72640..d16681ac59 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -36,6 +36,7 @@ class DeviceMgr {
public:
// Takes ownership of each device in 'devices'.
// TODO(zhifengc): Other initialization information.
+ // TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr();
@@ -61,6 +62,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const;
private:
+ // TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
index b0540dfa95..4cd56e583c 100644
--- a/tensorflow/core/common_runtime/device_set.h
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -39,7 +39,10 @@ class DeviceSet {
// Set the device designated as the "client". This device
// must also be registered via AddDevice().
- void set_client_device(Device* device) { client_device_ = device; }
+ void set_client_device(Device* device) {
+ DCHECK(client_device_ == nullptr);
+ client_device_ = device;
+ }
// Returns a pointer to the device designated as the "client".
Device* client_device() const { return client_device_; }
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
index ff20ee94a7..0507076c8c 100644
--- a/tensorflow/core/common_runtime/device_set_test.cc
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -27,8 +27,7 @@ namespace {
static Device* Dev(const char* type, const char* name) {
class FakeDevice : public Device {
public:
- explicit FakeDevice(const DeviceAttributes& attr)
- : Device(nullptr, attr, nullptr) {}
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0e2343cfe3..02f70d835d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
int gpu_id, const string& physical_device_desc,
Allocator* gpu_allocator, Allocator* cpu_allocator,
bool sync_every_op, int32 max_streams)
- : LocalDevice(options,
- Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit,
- locality, physical_device_desc),
- gpu_allocator),
+ : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
+ memory_limit, locality,
+ physical_device_desc)),
gpu_allocator_(gpu_allocator),
cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id),
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index 0a6342ed73..3f7c9f68db 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
};
LocalDevice::LocalDevice(const SessionOptions& options,
- const DeviceAttributes& attributes,
- Allocator* device_allocator)
- : Device(options.env, attributes, device_allocator),
- owned_tp_info_(nullptr) {
+ const DeviceAttributes& attributes)
+ : Device(options.env, attributes), owned_tp_info_(nullptr) {
// If we're running on the CPU, log warnings if we're not compiled using the
// best flags for performance.
port::WarnAboutUnusedCPUFeatures();
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index d1c27c6248..84a4f66db4 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -33,8 +33,8 @@ struct SessionOptions;
// GPUDevice into more 'process-wide' abstractions.
class LocalDevice : public Device {
public:
- LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
- Allocator* device_allocator);
+ LocalDevice(const SessionOptions& options,
+ const DeviceAttributes& attributes);
~LocalDevice() override;
private:
diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc
new file mode 100644
index 0000000000..fa9713735e
--- /dev/null
+++ b/tensorflow/core/common_runtime/renamed_device.cc
@@ -0,0 +1,54 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/renamed_device.h"
+
+namespace tensorflow {
+
+// TODO(saeta): Convert to returning a std::unique_ptr?
+/* static */
+Device* RenamedDevice::NewRenamedDevice(const string& new_base,
+ Device* underlying,
+ bool owns_underlying) {
+ DeviceNameUtils::ParsedName parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
+ DeviceNameUtils::ParsedName underlying_parsed_name =
+ underlying->parsed_name();
+ CHECK(underlying_parsed_name.has_type);
+ CHECK(underlying_parsed_name.has_id);
+ parsed_name.type = underlying_parsed_name.type;
+ parsed_name.id = underlying_parsed_name.id;
+ string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
+ parsed_name.task, parsed_name.type,
+ parsed_name.id);
+ DeviceAttributes attributes(underlying->attributes());
+ attributes.set_name(name);
+ return new RenamedDevice(underlying, attributes, owns_underlying);
+}
+
+RenamedDevice::RenamedDevice(Device* underlying,
+ const DeviceAttributes& attributes,
+ bool owns_underlying)
+ : Device(underlying->env(), attributes),
+ underlying_(underlying),
+ owns_underlying_(owns_underlying) {}
+
+RenamedDevice::~RenamedDevice() {
+ if (owns_underlying_) {
+ delete underlying_;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
new file mode 100644
index 0000000000..0158e18ced
--- /dev/null
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -0,0 +1,119 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// Wraps a device with a new name, delegating work to the wrapped device.
+//
+// This class is used to wrap local devices when using clusterspec propagation
+// where the name of a particular device may change in the context of a given
+// session.
+class RenamedDevice : public Device {
+ public:
+ static Device* NewRenamedDevice(const string& new_base, Device* underlying,
+ bool owns_underlying);
+ ~RenamedDevice() override;
+
+ // Below are virtual methods defined on DeviceBase
+ bool RequiresRecordingAccessedTensors() const override {
+ return underlying_->RequiresRecordingAccessedTensors();
+ }
+
+ const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
+ return underlying_->tensorflow_cpu_worker_threads();
+ }
+
+ const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
+ return underlying_->tensorflow_gpu_device_info();
+ }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ return underlying_->GetAllocator(attr);
+ }
+
+ Allocator* GetStepAllocator(AllocatorAttributes attr,
+ ResourceMgr* step_resource_manager) override {
+ return underlying_->GetStepAllocator(attr, step_resource_manager);
+ }
+
+ const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
+ return underlying_->eigen_cpu_device();
+ }
+
+#ifdef TENSORFLOW_USE_SYCL
+ const Eigen::SyclDevice* eigen_sycl_device() const override {
+ return underlying_->eigen_sycl_device();
+ }
+#endif
+
+ PerOpGpuDevice* MakeGpuDevice() override {
+ return underlying_->MakeGpuDevice();
+ }
+
+ void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc, Allocator* allocator) override {
+ underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ }
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override {
+ return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
+ }
+
+ // Below are virtual methods defined on Device
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+ underlying_->Compute(op_kernel, context);
+ }
+
+ void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override {
+ underlying_->ComputeAsync(op_kernel, context, std::move(done));
+ }
+
+ void ConsumeListOfAccessedTensors(
+ DeviceContext* context, const TensorReferenceVector& tensors) override {
+ underlying_->ConsumeListOfAccessedTensors(context, tensors);
+ }
+
+ Status Sync() override { return underlying_->Sync(); }
+
+ Status MaybeRewriteGraph(const FunctionDefLibrary& library,
+ std::unique_ptr<Graph>* graph) override {
+ return underlying_->MaybeRewriteGraph(library, graph);
+ }
+
+ Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) override {
+ return underlying_->FillContextMap(graph, device_context_map);
+ }
+
+ private:
+ RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
+ bool owns_underlying);
+ Device* const underlying_;
+ const bool owns_underlying_;
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index bd84417b10..24f27af5f1 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
class FakeDevice : public Device {
private:
explicit FakeDevice(const DeviceAttributes& device_attributes)
- : Device(nullptr, device_attributes, nullptr) {}
+ : Device(nullptr, device_attributes) {}
public:
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 60348e885f..f5f8aab694 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& locality,
Allocator* allocator)
- : LocalDevice(options,
- Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit,
- locality),
- allocator),
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_CPU, memory_limit, locality)),
allocator_(allocator) {}
ThreadPoolDevice::~ThreadPoolDevice() {}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 0f5eb0cb32..d2a828f39f 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -77,7 +77,6 @@ cc_library(
],
deps = [
":graph_mgr",
- ":rendezvous_mgr_interface",
":worker_cache",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
@@ -92,9 +91,9 @@ cc_library(
deps = [
":graph_mgr",
":worker_session",
+ "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@@ -237,6 +236,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
],
)
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 5863727f19..e68aea46ec 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -35,9 +35,8 @@ limitations under the License.
namespace tensorflow {
-BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env,
- const string& worker_name)
- : worker_env_(worker_env), worker_name_(worker_name) {}
+BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
+ : worker_env_(worker_env) {}
BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) {
@@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
}
}
-Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
+RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id);
}
@@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) {
- auto rr = Create(step_id, worker_env_, worker_name_);
+ auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first;
}
iter->second->Ref();
@@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
}
}
-BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
- const string& worker_name,
- int64 step_id,
+BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv)
: env_(env),
- worker_name_(worker_name),
step_id_(step_id),
- local_(NewLocalRendezvous(tolerate_dup_recv)) {}
+ local_(NewLocalRendezvous(tolerate_dup_recv)),
+ session_(nullptr) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty());
@@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
return device_name.starts_with(worker_name);
}
+Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
+ CHECK_NE(session, nullptr) << "session must not be null!";
+ std::vector<DeferredCall> deferred_calls;
+ {
+ mutex_lock l(mu_);
+ if (session_ != nullptr) {
+ if (session_->worker_name == session->worker_name) {
+ LOG(INFO) << "Skipping rendezvous re-initialization.";
+ return Status::OK();
+ }
+ Status s = errors::Internal(
+ "Double init! Worker names would have changed from: ",
+ session_->worker_name, " -> ", session->worker_name);
+ LOG(WARNING) << s;
+ return s;
+ }
+ session_ = session;
+ std::swap(deferred_calls, deferred_calls_);
+ }
+ for (DeferredCall& call : deferred_calls) {
+ RecvLocalAsyncInternal(call.parsed, std::move(call.done));
+ }
+ return Status::OK();
+}
+
+WorkerSession* BaseRemoteRendezvous::session() {
+ mutex_lock l(mu_);
+ return session_;
+}
+
+bool BaseRemoteRendezvous::is_initialized() {
+ mutex_lock l(mu_);
+ return is_initialized_locked();
+}
+
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
@@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
- }
- if (!IsLocalDevice(worker_name_, parsed.src_device)) {
- return errors::InvalidArgument("Invalid rendezvous key (src): ",
- parsed.FullKey(), " @ ", worker_name_);
+ DCHECK(is_initialized_locked());
+ if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
+ return errors::InvalidArgument(
+ "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
+ session_->worker_name);
+ }
}
// Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead);
@@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) {
+ // Cache session pointer to avoid repeatedly taking & releasing the lock
+ // (e.g. calling session())
+ WorkerSession* sess = nullptr;
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
+ if (!is_initialized_locked()) {
+ return errors::Internal("ValidateDevices called before initialization.");
+ }
+ sess = session_;
}
- if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) {
+ if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
- parsed.FullKey(), " @ ", worker_name_);
+ parsed.FullKey(), " @ ", sess->worker_name);
}
- if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) {
+ if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
- parsed.FullKey(), " @ ", worker_name_);
+ parsed.FullKey(), " @ ", sess->worker_name);
}
return Status::OK();
}
@@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
+ CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
@@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) {
+ {
+ mutex_lock l(mu_);
+ if (!is_initialized_locked()) {
+ // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
+ // remote worker) before the RunStep (or PartialRunStep) RPC from the
+ // master arrives. RecvLocalAsync thus buffers the arguments until after
+ // the RemoteRendezvous is Initialize()'d, when it completes the
+ // rendezvous logic. At some point after Initialize() is called, a Tensor
+ // is produced locally that will then be sent in response to the incoming
+ // RPC.
+ DeferredCall call(parsed, std::move(done));
+ deferred_calls_.push_back(call);
+ return;
+ }
+ }
+ RecvLocalAsyncInternal(parsed, std::move(done));
+}
+
+void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
+ DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
@@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
active_.erase(call);
}
+BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
+ DoneCallback done)
+ : parsed(parsed), done(std::move(done)) {}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
index 447a75913d..b252f45fe9 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -59,15 +59,17 @@ class BaseRecvTensorCall;
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
public:
- explicit BaseRendezvousMgr(const WorkerEnv* worker_env,
- const string& worker_name);
+ explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
- Rendezvous* Find(int64 step_id) override;
+ //
+ // Note: the caller must guarantee to eventually call Initialize on the
+ // returned RemoteRendezvous
+ RemoteRendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
@@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
protected:
virtual BaseRemoteRendezvous* Create(int64 step_id,
- const WorkerEnv* worker_env,
- const string& worker_name) = 0;
+ const WorkerEnv* worker_env) = 0;
private:
// Maps step_id to rendezvous.
@@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Not owned.
const WorkerEnv* const worker_env_;
- const string worker_name_;
mutex mu_;
Table table_ GUARDED_BY(mu_);
@@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers.
-class BaseRemoteRendezvous : public Rendezvous {
+class BaseRemoteRendezvous : public RemoteRendezvous {
public:
- BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
- int64 step_id, bool tolerate_dup_recv);
+ BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
+ bool tolerate_dup_recv);
+
+ // Upgrades the BaseRemoteRendezvous to full initialization.
+ Status Initialize(WorkerSession* session) override;
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
@@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
// Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call);
+ WorkerSession* session();
+
+ bool is_initialized();
+
~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned.
- const string worker_name_;
const int64 step_id_;
private:
@@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
// Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_);
+ WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
+
+ // Data structures to handle calls when partially initialized.
+ struct DeferredCall {
+ const ParsedKey parsed;
+ DoneCallback done;
+
+ DeferredCall(const ParsedKey& parsed, DoneCallback done);
+ };
+ std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls.
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
+ bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return session_ != nullptr;
+ }
+
// If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process.
@@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done);
+ // Must be called only if fully initialized.
+ void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
+
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index ce7ce372e8..5bde771e8d 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -46,10 +46,8 @@ limitations under the License.
namespace tensorflow {
-GraphMgr::GraphMgr(const WorkerEnv* worker_env,
- RendezvousMgrInterface* rendezvous_mgr)
- : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
- CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
+GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
+ : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
Status status =
@@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
};
popts.get_incarnation = [this](const string& name) -> int64 {
Device* device = nullptr;
- Status s = worker_env_->device_mgr->LookupDevice(name, &device);
+ Status s = device_mgr_->LookupDevice(name, &device);
if (s.ok()) {
return device->attributes().incarnation();
} else {
@@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
ExecutionUnit* unit = &(item->units.back());
// Find the device.
- Status s =
- worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
+ Status s = device_mgr_->LookupDevice(device_name, &unit->device);
if (!s.ok()) {
// Remove the empty unit from the item as the item destructor wants all
// units to have valid devices.
@@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
// Function library runtime.
unit->lib = NewFunctionLibraryRuntime(
- worker_env_->device_mgr, worker_env_->env, unit->device,
+ device_mgr_, worker_env_->env, unit->device,
subgraph->versions().producer(), item->lib_def,
graph_options.optimizer_options());
@@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
rendezvous->Unref();
return s;
}
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out);
rendezvous->Unref();
return s;
@@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) {
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) {
rendezvous->Unref();
@@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
- const ExecutorOpts& opts,
+ WorkerSession* session,
+ const ExecutorOpts& /*opts*/,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
- Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
+ RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
+ Status s = rendezvous->Initialize(session);
// Sends values specified by the caller.
- Status s = SendInputsToRendezvous(rendezvous, in);
+ if (s.ok()) {
+ s = SendInputsToRendezvous(rendezvous, in);
+ }
+
if (!s.ok()) {
done(s);
item->Unref();
@@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
StatusCallback done) {
const int num_units = item->units.size();
CHECK_GE(num_units, 1);
- ScopedStepContainer* step_container =
- new ScopedStepContainer(step_id, [this](const string& name) {
- worker_env_->device_mgr->ClearContainers({name});
- });
+ ScopedStepContainer* step_container = new ScopedStepContainer(
+ step_id,
+ [this](const string& name) { device_mgr_->ClearContainers({name}); });
// NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier =
new ExecutorBarrier(num_units, rendezvous,
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 349af6c54e..50391f47e4 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -37,6 +37,8 @@ namespace tensorflow {
class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
+class DeviceMgr;
+struct WorkerSession;
// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
@@ -62,8 +64,7 @@ class RendezvousMgrInterface;
// EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
public:
- explicit GraphMgr(const WorkerEnv* worker_env,
- RendezvousMgrInterface* rendezvous_mgr);
+ explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
~GraphMgr();
// Registers a graph. Fills in "handle"
@@ -78,8 +79,8 @@ class GraphMgr {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id,
- const ExecutorOpts& opts, StepStatsCollector* collector,
- CostGraphDef* cost_graph,
+ WorkerSession* session, const ExecutorOpts& opts,
+ StepStatsCollector* collector, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done);
@@ -131,7 +132,7 @@ class GraphMgr {
};
const WorkerEnv* worker_env_; // Not owned.
- RendezvousMgrInterface* rendezvous_mgr_; // Not owned.
+ DeviceMgr* device_mgr_;
CostModelManager cost_model_manager_;
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index b4adee3bf6..e860c99d95 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -48,12 +49,17 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
+namespace {
+const char* const kGrpcProtocol = "grpc://";
+} // namespace
+
Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env),
last_1000_steps_(1000),
@@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
Status status;
+ WorkerCacheFactoryOptions worker_cache_factory_options;
+ string grpc_protocol("grpc");
+ worker_cache_factory_options.protocol = &grpc_protocol;
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
status = ValidateExternalGraphDefSyntax(req->graph_def());
if (!status.ok()) return;
- // Ping all the workers and build the list of devices that the
- // session will use.
+
+ // The following 4 variables are set differently, depending on whether this
+ // session uses a client-provided clusterspec or not.
+ WorkerCacheInterface* worker_cache = nullptr;
+ // Note: worker_cache_ptr will be null except if this session is using a
+ // client-supplied ClusterDef (ClusterSpec propagation).
+ std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
+ std::unique_ptr<DeviceSet> device_set;
// TODO(saeta): Convert to std::make_unique when available.
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
new std::vector<std::unique_ptr<Device>>());
- status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
- env_, env_->worker_cache,
- remote_devices.get());
- if (!status.ok()) return;
+
+ if (req->config().has_cluster_def()) {
+ worker_cache_factory_options.cluster_def = &req->config().cluster_def();
+
+ // Set the server_def's job_name and task_index fields.
+ string normalized_string;
+ string grpc_protocol(kGrpcProtocol);
+ if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
+ 0) {
+ normalized_string =
+ req->target().substr(grpc_protocol.length(), string::npos);
+ } else {
+ normalized_string = req->target();
+ }
+ for (auto&& job : req->config().cluster_def().job()) {
+ for (auto&& task : job.tasks()) {
+ if (task.second == normalized_string) {
+ if (worker_cache_factory_options.job_name != nullptr) {
+ status = errors::InvalidArgument(
+ "Found multiple matching tasks that correspond to "
+ "to the master. Master target: '",
+ req->target(), "'. ClusterDef: ",
+ req->config().cluster_def().ShortDebugString());
+ LOG(ERROR) << status;
+ return;
+ }
+ if (env_->local_devices[0]->parsed_name().job == job.name() &&
+ env_->local_devices[0]->parsed_name().task == task.first) {
+ // TODO(b/37868888): Remove this limitation when resolved
+ status = errors::InvalidArgument(
+ "The ClusterSpec names the job and task index to be the same "
+ "names that were provided when the server booted. This is "
+ "currently not allowed. Job: ",
+ job.name(), ", task index: ", task.first);
+ return;
+ }
+ worker_cache_factory_options.job_name = &job.name();
+ worker_cache_factory_options.task_index = task.first;
+ }
+ }
+ }
+
+ // Create the worker cache from the computed server_def.
+ status = env_->worker_cache_factory(worker_cache_factory_options,
+ &worker_cache);
+ if (!status.ok()) return;
+ worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
+ // Ping all the workers and build the list of devices that the
+ // session will use.
+ status =
+ DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
+ worker_cache, remote_devices.get());
+ if (!status.ok()) return;
+ device_set.reset(new DeviceSet);
+ for (auto&& d : *remote_devices) {
+ device_set->AddDevice(d.get());
+ DeviceNameUtils::ParsedName name = d->parsed_name();
+ if (name.job == *worker_cache_factory_options.job_name &&
+ name.task == worker_cache_factory_options.task_index &&
+ name.type == "CPU") {
+ device_set->set_client_device(d.get());
+ }
+ }
+ } else {
+ worker_cache = env_->worker_cache;
+ // Ping all the workers and build the list of devices that the
+ // session will use.
+ status =
+ DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
+ worker_cache, remote_devices.get());
+ if (!status.ok()) return;
+ device_set.reset(new DeviceSet);
+ for (auto&& d : *remote_devices) {
+ device_set->AddDevice(d.get());
+ }
+ int num_local_devices = 0;
+ for (Device* d : env_->local_devices) {
+ device_set->AddDevice(d);
+ if (num_local_devices == 0) {
+ // Uses the first local device as the client device.
+ device_set->set_client_device(d);
+ }
+ num_local_devices++;
+ }
+ }
+
+ CHECK(device_set->client_device());
+
SessionOptions options;
options.config = req->config();
- MasterSession* session =
- env_->master_session_factory(options, env_, std::move(remote_devices));
+
+ MasterSession* session = env_->master_session_factory(
+ options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
+ std::move(device_set));
+
GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
- status = session->Create(gdef);
+
+ status = session->Create(gdef, worker_cache_factory_options);
if (!status.ok()) {
session->Close().IgnoreError();
session->Unref();
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index a155bd384d..bb548adda1 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -19,17 +19,41 @@ limitations under the License.
#include <functional>
#include <vector>
-#include "tensorflow/core/distributed_runtime/master_session.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
+class DeviceSet;
class Env;
class MasterSession;
class OpRegistryInterface;
class WorkerCacheInterface;
+// Options passed to the worker_cache_factory function.
+struct WorkerCacheFactoryOptions {
+ const ClusterDef* cluster_def = nullptr;
+ const string* job_name = nullptr;
+ int task_index;
+ const string* protocol = nullptr;
+
+ WorkerCacheFactoryOptions() {}
+
+ // Construct from a ServerDef proto.
+ //
+ // Note: server_def must outlive WorkerCacheFactoryOptions!
+ WorkerCacheFactoryOptions(const ServerDef& server_def) {
+ if (server_def.has_cluster() && !server_def.job_name().empty()) {
+ cluster_def = &server_def.cluster();
+ job_name = &server_def.job_name();
+ task_index = server_def.task_index();
+ protocol = &server_def.protocol();
+ }
+ }
+};
+
// The master environment class, which holds a bag of pointers to
// per-master state.
//
@@ -57,8 +81,14 @@ struct MasterEnv {
// `MasterEnv*` is retained by the caller.
std::function<MasterSession*(
SessionOptions, MasterEnv*,
- std::unique_ptr<std::vector<std::unique_ptr<Device>>>)>
+ std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
+ std::unique_ptr<WorkerCacheInterface>,
+ std::unique_ptr<DeviceSet> device_set)>
master_session_factory;
+
+ std::function<Status(const WorkerCacheFactoryOptions&,
+ WorkerCacheInterface**)>
+ worker_cache_factory;
};
} // end namespace tensorflow
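WorkerCacheFactoryOptions is essentially a borrowed view of a ServerDef. A sketch of the corresponding proto, as the master later assembles per worker in CreateWorkerSessions (field values hypothetical):

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222"]})

    server_def = tf.train.ServerDef(
        cluster=cluster.as_cluster_def(),  # <- WorkerCacheFactoryOptions.cluster_def
        job_name="worker",                 # <- job_name
        task_index=0,                      # <- task_index
        protocol="grpc")                   # <- protocol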
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 5257aea1e3..50c5d90fc9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -528,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_is_partial(is_partial_);
c->req->set_is_last_partial_run(is_last_partial_run);
}
+ c->req->set_session_handle(session_handle_);
c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
@@ -871,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
// The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) {
Call* c = new Call;
+ c->req.set_session_handle(session_handle_);
c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called.
@@ -973,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
MasterSession::MasterSession(
const SessionOptions& opt, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt),
env_(env),
handle_(strings::FpToString(random::New64())),
remote_devs_(std::move(remote_devs)),
+ worker_cache_(std::move(worker_cache)),
+ devices_(std::move(device_set)),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
run_graphs_(5),
partial_run_graphs_(5) {
UpdateLastAccessTime();
+ CHECK(devices_) << "device_set was null!";
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_->size();
- for (auto&& d : *remote_devs_) {
- devices_.AddDevice(d.get());
- }
- int num_local_devices = 0;
- for (Device* d : env->local_devices) {
- devices_.AddDevice(d);
- if (num_local_devices == 0) {
- // Uses the first local device as the client device.
- devices_.set_client_device(d);
- }
- num_local_devices++;
- }
+
LOG(INFO) << "Start master session " << handle_
<< " with config: " << std::endl
<< session_opts_.config.DebugString();
@@ -1012,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros());
}
-Status MasterSession::Create(GraphDef* graph_def) {
+Status MasterSession::Create(GraphDef* graph_def,
+ const WorkerCacheFactoryOptions& options) {
if (session_opts_.config.graph_options().place_pruned_graph()) {
// TODO(b/29900832): Fix this or remove the option.
LOG(WARNING) << "Distributed session does not support the "
@@ -1020,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
}
- SimpleGraphExecutionStateOptions options;
- options.device_set = &devices_;
- options.session_options = &session_opts_;
+ SimpleGraphExecutionStateOptions execution_options;
+ execution_options.device_set = devices_.get();
+ execution_options.session_options = &session_opts_;
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
- graph_def, options, &execution_state_));
+ graph_def, execution_options, &execution_state_));
+ }
+ if (options.cluster_def != nullptr) {
+ return CreateWorkerSessions(options);
}
return Status::OK();
}
+Status MasterSession::CreateWorkerSessions(
+ const WorkerCacheFactoryOptions& options) {
+ CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
+ << "dynamic cluster membership.";
+ std::vector<string> worker_names;
+ worker_cache_->ListWorkers(&worker_names);
+
+ struct WorkerGroup {
+ // The worker name. (Not owned.)
+ const string* name;
+
+ // The worker referenced by name. (Not owned.)
+ WorkerInterface* worker = nullptr;
+
+ // Request and responses used for a given worker.
+ CreateWorkerSessionRequest request;
+ CreateWorkerSessionResponse response;
+ Status status = Status::OK();
+ };
+ BlockingCounter done(worker_names.size());
+ std::vector<WorkerGroup> workers(worker_names.size());
+
+ // Release the workers.
+ auto cleanup = gtl::MakeCleanup([this, &workers] {
+ for (auto&& worker_group : workers) {
+ if (worker_group.worker != nullptr) {
+ worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
+ }
+ }
+ });
+
+ Status status = Status::OK();
+ // Create all the workers & kick off the computations.
+ for (size_t i = 0; i < worker_names.size(); ++i) {
+ workers[i].name = &worker_names[i];
+ workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
+ workers[i].request.set_session_handle(handle_);
+ *workers[i].request.mutable_server_def()->mutable_cluster() =
+ *options.cluster_def;
+ workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+
+ DeviceNameUtils::ParsedName name;
+ if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
+ status = errors::Internal("Could not parse name ", worker_names[i]);
+ LOG(WARNING) << status;
+ return status;
+ }
+ if (!name.has_job || !name.has_task) {
+ status = errors::Internal("Incomplete worker name ", worker_names[i]);
+ LOG(WARNING) << status;
+ return status;
+ }
+
+ workers[i].request.mutable_server_def()->set_job_name(name.job);
+ workers[i].request.mutable_server_def()->set_task_index(name.task);
+ }
+
+ for (size_t i = 0; i < worker_names.size(); ++i) {
+ auto cb = [i, &workers, &done](const Status& s) {
+ workers[i].status = s;
+ done.DecrementCount();
+ };
+ workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
+ &workers[i].response, cb);
+ }
+
+ done.Wait();
+ for (size_t i = 0; i < workers.size(); ++i) {
+ status.Update(workers[i].status);
+ }
+ return status;
+}
+
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
@@ -1060,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK();
}
+WorkerCacheInterface* MasterSession::get_worker_cache() const {
+ if (worker_cache_) {
+ return worker_cache_.get();
+ }
+ return env_->worker_cache;
+}
+
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts);
@@ -1083,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
<< "\n";
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
+ WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial,
- env_->worker_cache);
-
+ worker_cache);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
@@ -1162,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
return errors::FailedPrecondition("Session is closed.");
}
++num_running_;
+ // Note: all code paths must eventually call MarkRunCompletion()
+ // in order to appropriately decrement the num_running_ counter.
}
Status status;
if (!req.partial_run_handle().empty()) {
@@ -1169,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
} else {
status = DoRunWithLocalExecution(opts, req, resp);
}
- {
- mutex_lock l(mu_);
- --num_running_;
- if (num_running_ == 0) {
- num_running_is_zero_.notify_all();
- }
- }
return status;
}
+// Decrements num_running_ and broadcasts if num_running_ is zero.
+void MasterSession::MarkRunCompletion() {
+ mutex_lock l(mu_);
+ --num_running_;
+ if (num_running_ == 0) {
+ num_running_is_zero_.notify_all();
+ }
+}
+
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
// Registers subgraphs if we haven't already done so.
PartitionOptions popts;
@@ -1188,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.get_incarnation = [this](const string& name) -> int64 {
- Device* d = devices_.FindDeviceByName(name);
+ Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation;
} else {
@@ -1223,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
Status MasterSession::DoPartialRun(CallOptions* opts,
const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
const string& prun_handle = req.partial_run_handle();
RunState* run_state = nullptr;
{
@@ -1321,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
rcg->Ref();
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
req.options(), resp->mutable_metadata());
+ cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
+ MarkRunCompletion();
});
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
@@ -1368,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
- VLOG(2) << "DoRunWithLocalExecution "
- << "req: " << req.DebugString();
+ VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
// Prepare.
BuildGraphOptions bgopts;
@@ -1438,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
}
}
rcg->Ref();
- rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
+ cleanup.release(); // MarkRunCompletion called in done closure.
+ rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
+ MarkRunCompletion();
});
return s;
}
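A minimal, self-contained sketch of the fan-out/join pattern CreateWorkerSessions uses above: issue one asynchronous request per worker, record a per-worker result, and block on a counter until every callback has fired. BlockingCounter and FakeAsyncRpc below are illustrative stand-ins for tensorflow::BlockingCounter and WorkerInterface::CreateWorkerSessionAsync, not the real APIs.

#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Stand-in for tensorflow::BlockingCounter.
class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> l(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Stand-in for an async worker RPC: runs `done` on another thread.
void FakeAsyncRpc(const std::string& worker,
                  std::function<void(bool)> done) {
  std::thread([done] { done(true); }).detach();
}

int main() {
  std::vector<std::string> workers = {"/job:worker/task:0",
                                      "/job:worker/task:1"};
  std::vector<int> ok(workers.size(), 0);  // per-worker result slots
  BlockingCounter done(workers.size());
  for (size_t i = 0; i < workers.size(); ++i) {
    FakeAsyncRpc(workers[i], [i, &ok, &done](bool s) {
      ok[i] = s ? 1 : 0;      // record this worker's status
      done.DecrementCount();  // signal completion
    });
  }
  done.Wait();  // all callbacks have run; ok[] is safe to read
  return 0;
}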
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index d47125be99..3acc5bc5f0 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
@@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
MasterSession(
const SessionOptions& options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close().
//
// After this method returns, `def` will no longer be valid.
- Status Create(GraphDef* def);
+ Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
// Returns the session handle.
const string& handle() const { return handle_; }
@@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
+ // The optional session-specific worker cluster.
+ // TODO(saeta): Convert to std::optional when available.
+ std::unique_ptr<WorkerCacheInterface> worker_cache_;
+ // Returns worker_cache_ if set; otherwise returns env_->worker_cache.
+ WorkerCacheInterface* get_worker_cache() const;
+
// The device set used by this session.
- DeviceSet devices_;
+ std::unique_ptr<DeviceSet> devices_;
StatsPublisherFactory stats_publisher_factory_;
@@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
// Private dtor. The client must call Close().
virtual ~MasterSession();
+ // Creates sessions on all workers.
+ //
+ // If this session is operating under the new ClusterSpec propagation
+ // behavior, call this method to propagate the cluster membership to all
+ // workers.
+ Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
+
Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
@@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
+ void MarkRunCompletion();
void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
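DoPartialRun and DoRunWithLocalExecution above pair MarkRunCompletion with a scope guard: the guard fires on every early-exit path, and release() hands the obligation to an async done-closure. A minimal stand-in for the gtl::MakeCleanup idiom (illustrative, not the TensorFlow implementation):

#include <utility>

// Runs a callback at scope exit unless release() was called.
template <typename F>
class Cleanup {
 public:
  explicit Cleanup(F f) : f_(std::move(f)), armed_(true) {}
  Cleanup(Cleanup&& other) : f_(std::move(other.f_)), armed_(other.armed_) {
    other.armed_ = false;
  }
  ~Cleanup() {
    if (armed_) f_();
  }
  void release() { armed_ = false; }  // obligation handed elsewhere

 private:
  F f_;
  bool armed_;
};

template <typename F>
Cleanup<F> MakeCleanup(F f) { return Cleanup<F>(std::move(f)); }

void MarkRunCompletion() { /* e.g. decrement num_running_ and notify */ }

void RunOnce(bool hands_off_to_callback) {
  auto cleanup = MakeCleanup([] { MarkRunCompletion(); });
  if (hands_off_to_callback) {
    // An async continuation now owns the call to MarkRunCompletion().
    cleanup.release();
    return;
  }
  // Otherwise MarkRunCompletion() runs when `cleanup` leaves scope,
  // including on every early-return or error path.
}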
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index 7b58feb93c..b077975ea5 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
+const string& InMemoryRunGraphRequest::session_handle() const {
+ return session_handle_;
+}
+
+void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
+ session_handle_ = handle;
+}
+
const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_;
}
@@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
+ proto_version_->set_session_handle(session_handle());
proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts();
@@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
return *proto_version_;
}
+const string& MutableProtoRunGraphRequest::session_handle() const {
+ return request_.session_handle();
+}
+
+void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
+ request_.set_session_handle(handle);
+}
+
const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle();
}
@@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
: request_(request) {}
+const string& ProtoRunGraphRequest::session_handle() const {
+ return request_->session_handle();
+}
+
const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle();
}
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 02516eabb4..795a6add0e 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
public:
virtual ~RunGraphRequestWrapper() {}
+ // The session handle used to register the graph. If empty, a single global
+ // namespace is used.
+ virtual const string& session_handle() const = 0;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
virtual const string& graph_handle() const = 0;
@@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
// See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public:
+ virtual void set_session_handle(const string& handle) = 0;
virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0;
@@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
+ void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_last_partial_run(bool is_last_partial_run) override;
private:
+ string session_handle_;
string graph_handle_;
int64 step_id_;
ExecutorOpts exec_opts_;
@@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
+ void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
ProtoRunGraphRequest(const RunGraphRequest* request);
// RunGraphRequestWrapper methods.
+ const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
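Putting the new accessor to work: a master-side caller tags each RunGraph request with the owning session handle, so workers can look up the WorkerSession directly instead of consulting a graph_handle-to-session side table. A hedged sketch against the wrapper API introduced in this change:

#include <string>

#include "tensorflow/core/distributed_runtime/message_wrappers.h"

// Sketch: fill the fields a worker needs to resolve the owning
// WorkerSession (session_handle) and the registered graph (graph_handle).
void FillRunGraphRequest(tensorflow::InMemoryRunGraphRequest* req,
                         const std::string& session_handle,
                         const std::string& graph_handle,
                         tensorflow::int64 step_id) {
  req->set_session_handle(session_handle);  // new in this change
  req->set_graph_handle(graph_handle);
  req->set_step_id(step_id);
}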
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 9632e9c439..91c1fb99fe 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector>
+
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
class RemoteDevice : public Device {
public:
RemoteDevice(Env* env, const DeviceAttributes& da)
- : Device(env, da, nullptr),
- local_dev_name_(GetLocalDeviceName(da.name())) {}
+ : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
@@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
GetStatusResponse resp;
};
Call* call = new Call;
- auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
+ auto cb = [env, worker_cache, worker_name, done, wi,
+ call](const Status& status) {
+ Status s = status;
std::vector<Device*> remote_devices;
+ auto cleanup = gtl::MakeCleanup(
+ [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
+ worker_cache->ReleaseWorker(worker_name, wi);
+ done(s, &remote_devices);
+ delete call;
+ });
if (s.ok()) {
+ DeviceNameUtils::ParsedName worker_name_parsed;
+ if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
+ !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
+ !worker_name_parsed.has_task) {
+ s = errors::InvalidArgument("Could not parse worker name: ",
+ worker_name);
+ LOG(WARNING) << s;
+ return;
+ }
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
- auto d = new RemoteDevice(env, da);
- remote_devices.push_back(d);
+ DeviceNameUtils::ParsedName device_name_parsed;
+ CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
+ << "Device attribute name '" << da.name() << "' could not be "
+ << "parsed. Device Attribute: " << da.DebugString();
+ // Preserve the exact name, if possible.
+ // TODO(b/37868888): Simplify when legacy device name formats are removed.
+ if (device_name_parsed.job == worker_name_parsed.job &&
+ device_name_parsed.replica == worker_name_parsed.replica &&
+ device_name_parsed.task == worker_name_parsed.task) {
+ auto d = new RemoteDevice(env, da);
+ remote_devices.push_back(d);
+ } else {
+ DeviceAttributes da_rewritten = da;
+ da_rewritten.set_name(DeviceNameUtils::FullName(
+ worker_name_parsed.job, worker_name_parsed.replica,
+ worker_name_parsed.task, device_name_parsed.type,
+ device_name_parsed.id));
+ auto d = new RemoteDevice(env, da_rewritten);
+ remote_devices.push_back(d);
+ }
}
}
- worker_cache->ReleaseWorker(worker_name, wi);
- done(s, &remote_devices);
- delete call;
};
wi->GetStatusAsync(&call->req, &call->resp, cb);
}
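The rewriting branch above re-homes a device under the worker's own job/replica/task while preserving the device type and id. A hedged sketch of that rule using DeviceNameUtils (error handling elided; the literal names are invented for illustration):

#include <string>

#include "tensorflow/core/util/device_name_utils.h"

// A device reported as "/job:local/replica:0/task:0/cpu:0" by the worker
// "/job:worker/replica:0/task:1" has a mismatched job/replica/task, so it
// is renamed to live under the worker, keeping its type ("cpu") and id (0).
std::string RehomeExample() {
  using tensorflow::DeviceNameUtils;
  DeviceNameUtils::ParsedName worker, device;
  DeviceNameUtils::ParseFullName("/job:worker/replica:0/task:1", &worker);
  DeviceNameUtils::ParseFullName("/job:local/replica:0/task:0/cpu:0", &device);
  return DeviceNameUtils::FullName(worker.job, worker.replica, worker.task,
                                   device.type, device.id);
}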
diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
index 04c1fc248e..43267d4362 100644
--- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
+++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
@@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow {
+struct WorkerSession;
+
+// RemoteRendezvous objects follow a two-part initialization: first they are
+// constructed, and only later are they initialized. Clients of the
+// RendezvousMgrInterface must guarantee to eventually call Initialize on
+// every returned RemoteRendezvous.
+//
+// A partially initialized RemoteRendezvous must still respect the Rendezvous
+// interface (i.e., Send() must never block); however, implementations are not
+// expected to actually perform the underlying operations until after the
+// RemoteRendezvous has been initialized.
+class RemoteRendezvous : public Rendezvous {
+ public:
+ // Fully construct the RemoteRendezvous.
+ virtual Status Initialize(WorkerSession* session) = 0;
+};
+
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@@ -51,7 +68,10 @@ class RendezvousMgrInterface {
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
- virtual Rendezvous* Find(int64 step_id) = 0;
+ //
+ // Note: the caller must guarantee to eventually call Initialize on the
+ // returned RemoteRendezvous.
+ virtual RemoteRendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 7160962b16..3867dd1f4d 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};
// static utility function
-RendezvousMgrInterface* NewRpcRendezvousMgr(
- const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* worker_cache) {
- return new RpcRendezvousMgr(env, worker_name, worker_cache);
+RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
+ return new RpcRendezvousMgr(env);
}
} // namespace
@@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
+ // Shut down all outstanding rendezvous.
+ delete worker_env_.rendezvous_mgr;
+
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
@@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
// OpSegments.)
if (worker_env_.session_mgr != nullptr) {
delete worker_env_.session_mgr; // Deletes graph_mgr's.
+ } else {
+ // No session_mgr: delete device_mgr directly. (Otherwise, session_mgr's
+ // legacy_session_ owns and deletes device_mgr.)
+ delete worker_env_.device_mgr;
}
- delete worker_env_.device_mgr;
// Do not delete (as these are not owned by the server):
// - master_env_.env
@@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool
}
-Status GrpcServer::Init(ServiceInitFunction service_func,
- RendezvousMgrCreationFunction rendevous_mgr_func) {
+Status GrpcServer::Init(
+ ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
"/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices));
- worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices);
+ worker_env_.local_devices = master_env_.local_devices;
+ worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
+ worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
+ ? new RpcRendezvousMgr(&worker_env_)
+ : rendezvous_mgr_func(&worker_env_);
string unused;
string default_worker_name;
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
@@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
}
WorkerCacheInterface* worker_cache;
- TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache));
+ WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
+ TF_RETURN_IF_ERROR(
+ WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
// Set up worker environment.
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- rendevous_mgr_func == nullptr ?
- new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
- rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache),
- std::move(rendezvous_mgr),
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
- return WorkerCacheFactory(server_def, worker_cache);
+ WorkerCacheFactoryOptions options(server_def);
+ return WorkerCacheFactory(options, worker_cache);
});
worker_env_.compute_pool = ComputePool(sess_opts);
@@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
master_env_.master_session_factory =
[config](
SessionOptions options, const MasterEnv* env,
- std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) {
+ std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceSet> device_set) {
options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs),
+ std::move(worker_cache), std::move(device_set),
CreateNoOpStatsPublisher);
};
+ master_env_.worker_cache_factory =
+ [this](const WorkerCacheFactoryOptions& options,
+ WorkerCacheInterface** worker_cache) {
+ return WorkerCacheFactory(options, worker_cache);
+ };
// Provide direct access to the master from in-process clients.
LocalMaster::Register(target(), master_impl_.get(),
@@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK();
}
-Status GrpcServer::Init() {
- return Init(nullptr, nullptr);
-}
+Status GrpcServer::Init() { return Init(nullptr, nullptr); }
-Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
+Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
- for (const auto& job : server_def.cluster().job()) {
+ for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports;
for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first];
@@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
task.first, "\": ", host_port, " and ",
task.second);
}
- if (job.name() == server_def.job_name() &&
- task.first == server_def.task_index()) {
+ if (job.name() == *options.job_name && task.first == options.task_index) {
host_port = strings::StrCat("localhost:", bound_port_);
} else {
host_port = task.second;
@@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
return Status::OK();
}
-Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
+Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
- string name_prefix =
- strings::StrCat("/job:", server_def.job_name(), "/replica:0",
- "/task:", server_def.task_index());
+ if (options.job_name == nullptr || options.job_name->empty()) {
+ Status s = errors::InvalidArgument(
+ "The master (current machine) is not included in the provided "
+ "cluster_def. ",
+ options.cluster_def->DebugString());
+ LOG(WARNING) << s;
+ return s;
+ }
GrpcChannelSpec channel_spec;
- TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
+ TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
+
+ std::unique_ptr<GrpcChannelCache> channel_cache(
+ NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
+
+ string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
+ "/task:", options.task_index);
- std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
- channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
@@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
return ::grpc::InsecureServerCredentials();
}
-ChannelCreationFunction GrpcServer::GetChannelCreationFunction(
- const ServerDef& server_def) const {
+ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
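For reference, the input ParseChannelSpec consumes is the cluster_def carried by WorkerCacheFactoryOptions. A hypothetical two-task cluster (host names and ports invented): if this server runs as /job:worker/task:1 and bound_port_ is 2222, task 0 resolves to "hostA:2222" while this server's own entry is rewritten to "localhost:2222".

#include "tensorflow/core/protobuf/cluster.pb.h"

// Builds the illustrative ClusterDef described above.
tensorflow::ClusterDef MakeExampleCluster() {
  tensorflow::ClusterDef cluster;
  tensorflow::JobDef* job = cluster.add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "hostA:2222";
  (*job->mutable_tasks())[1] = "hostB:2222";
  return cluster;
}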
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 3b66291a9a..7b54bb84c8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -37,9 +37,7 @@ class GrpcWorker;
class Master;
// function that creates a RendezvousMgr.
-typedef std::function<RendezvousMgrInterface*(
- const WorkerEnv*, const std::string& worker_name,
- WorkerCacheInterface* worker_cache)>
+typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to
@@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
protected:
Status Init(ServiceInitFunction service_func,
- RendezvousMgrCreationFunction rendezvous_mgr_func);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init();
@@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const;
- virtual ChannelCreationFunction GetChannelCreationFunction(
- const ServerDef& server_def) const;
+ virtual ChannelCreationFunction GetChannelCreationFunction() const;
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session.
- Status WorkerCacheFactory(const ServerDef& server_def,
+ Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache);
- // Parses a ServerDef into a GrpcChannelSpec.
- Status ParseChannelSpec(const ServerDef& server_def,
+ // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
+ Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec);
// Returns the port to which this server is bound.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 1aacef8a26..38d59d5bb5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
/* static */
Status GrpcSession::Create(const SessionOptions& options,
std::unique_ptr<GrpcSession>* out_session) {
- std::unique_ptr<GrpcSession> ret(new GrpcSession(options));
+ std::unique_ptr<GrpcSession> session(new GrpcSession(options));
std::unique_ptr<MasterInterface> master;
// For testing, we enable the client to disable the use of the local
// master registry, so that the RPC stack is exercised.
@@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
options.target.substr(kSchemePrefixLength), &master_channel));
master.reset(NewGrpcMaster(master_channel));
}
- ret->SetRemoteMaster(std::move(master));
- *out_session = std::move(ret);
+ session->SetRemoteMaster(std::move(master));
+ *out_session = std::move(session);
return Status::OK();
}
@@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
CreateSessionRequest req;
*req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph;
+ req.set_target(options_.target);
ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp;
Status s = master_->CreateSession(call_options, &req, &resp);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index c11266587d..873ef8588f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// completes, and we may decide to bound some of the request
// types.
ENQUEUE_REQUEST(GetStatus, false);
+ ENQUEUE_REQUEST(CreateWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false);
@@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(GetStatus, false);
}
+ void CreateWorkerSessionHandler(
+ WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
+ call) {
+ Schedule([this, call]() {
+ Status s = worker_->CreateWorkerSession(&call->request, &call->response);
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ ENQUEUE_REQUEST(CreateWorkerSession, false);
+ }
+
void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() {
@@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response,
StatusCallback done) {
const int64 step_id = request->step_id();
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
@@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
// of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
- session->rendezvous_mgr->RecvLocalAsync(
+ env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev](const Status& status,
const Rendezvous::Args& send_args,
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 7518a289fd..8265100061 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -38,9 +38,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public:
- RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* cache, int64 step_id)
- : BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
+ RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
+ : BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private:
~RpcRemoteRendezvous() override {}
- WorkerCacheInterface* const cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
@@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
return call_freelist;
}
-// A private cache that wraps worker_cache and allows reuse of
-// WorkerInterface objects.
-class WorkerFreeListCache : public WorkerCacheInterface {
- public:
- explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
-
- ~WorkerFreeListCache() {
- for (auto p : workers_) {
- wrapped_->ReleaseWorker(p.first, p.second.worker);
- }
- }
-
- void ListWorkers(std::vector<string>* workers) const override {
- wrapped_->ListWorkers(workers);
- }
-
- WorkerInterface* CreateWorker(const string& target) override {
- mutex_lock l(mu_);
- auto p = workers_.find(target);
- if (p != workers_.end()) {
- return p->second.worker;
- }
- WorkerState state;
- state.worker = wrapped_->CreateWorker(target);
- if (state.worker != nullptr) {
- workers_.insert(std::make_pair(target, state));
- }
- return state.worker;
- }
-
- void ReleaseWorker(const string& target, WorkerInterface* worker) override {
- // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
- }
-
- bool GetDeviceLocalityNonBlocking(const string& device,
- DeviceLocality* locality) override {
- return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
- }
-
- void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
- StatusCallback done) override {
- wrapped_->GetDeviceLocalityAsync(device, locality, done);
- }
-
- void SetLogging(bool active) override { wrapped_->SetLogging(active); }
-
- void ClearLogs() override { wrapped_->ClearLogs(); }
-
- bool RetrieveLogs(int64 step_id, StepStats* ss) override {
- return wrapped_->RetrieveLogs(step_id, ss);
- }
-
- private:
- WorkerCacheInterface* wrapped_;
-
- // Information kept per created WorkerInterface.
- struct WorkerState {
- WorkerInterface* worker;
- // TODO(jeff,sanjay): Add reference count if we support eviction.
- };
-
- // TODO(jeff,sanjay): Eviction when the map becomes too big.
- mutex mu_;
- std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
-};
-
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
+ CHECK(is_initialized());
Status s;
// Prepare a RecvTensor call that can handle being aborted.
@@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device,
" is invalid remote source device.");
}
- WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
+ WorkerSession* sess = session();
+ WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_);
}
Device* dst_device;
if (s.ok()) {
- s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
- get_call_freelist()->Release(call, cache_);
+ if (rwi != nullptr) {
+ sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
+ }
+ get_call_freelist()->Release(call, sess->worker_cache.get());
done(s, Args(), recv_args, Tensor{}, false);
return;
}
@@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// current status should be bad.
Status s = call->status();
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
- cache_->ReleaseWorker(call->src_worker_, call->wi_);
+ session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
call->wi_ = nullptr;
- get_call_freelist()->Release(call, cache_);
+ get_call_freelist()->Release(call, session()->worker_cache.get());
Unref();
});
}
} // namespace
-RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env,
- const string& worker_name,
- WorkerCacheInterface* worker_cache)
- : BaseRendezvousMgr(env, worker_name),
- cache_(new WorkerFreeListCache(worker_cache)) {}
+RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
+ : BaseRendezvousMgr(env) {}
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
- const WorkerEnv* worker_env,
- const string& worker_name) {
- return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
- step_id);
+ const WorkerEnv* worker_env) {
+ return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
index 75dc62d98f..34c48a7917 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
@@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
-#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+class DeviceMgr;
+
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@@ -44,17 +44,12 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr {
public:
- explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name,
- WorkerCacheInterface* worker_cache);
+ explicit RpcRendezvousMgr(const WorkerEnv* env);
protected:
- BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
- const string& session_name) override;
+ BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
private:
- // Private cache_ that allows us to reuse WorkerInterface objects.
- std::unique_ptr<WorkerCacheInterface> cache_;
-
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 9b778eab3a..2d0d76623d 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache),
worker_session_("/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_),
- std::unique_ptr<RendezvousMgrInterface>(),
+ std::unique_ptr<DeviceMgr>(),
std::unique_ptr<GraphMgr>()),
- rmgr_(&env, worker_session_.worker_name, cache_) {
+ rmgr_(&env) {
env.env = Env::Default();
}
@@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort().
const int64 step_id = 123;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, rendez]() {
env.env->SleepForMicroseconds(100 * 1000);
@@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
{ // Cleanup causes Abort().
const int64 step_id = 321;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, step_id]() {
env.env->SleepForMicroseconds(100 * 1000);
@@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
}
@@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
const int64 step_id = 123;
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
- Rendezvous* rendez = rmgr_.Find(step_id);
+ RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
args.device_context = dc;
+ TF_ASSERT_OK(rendez->Initialize(&worker_session_));
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
}
{
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index e2be62f816..22551d5482 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -17,8 +17,9 @@ limitations under the License.
#include <utility>
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
-#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -26,23 +27,12 @@ namespace tensorflow {
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
- WorkerCacheFactory worker_cache_factory)
- : SessionMgr(
- worker_env, default_worker_name, std::move(default_worker_cache),
- default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
-
-SessionMgr::SessionMgr(
- WorkerEnv* worker_env, const string& default_worker_name,
- std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env),
- legacy_session_(
- default_worker_name, std::move(default_worker_cache),
- std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr),
- std::unique_ptr<GraphMgr>(
- new GraphMgr(worker_env, default_rendezvous_mgr))),
+ legacy_session_(default_worker_name, std::move(default_worker_cache),
+ std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
+ std::unique_ptr<GraphMgr>(
+ new GraphMgr(worker_env, worker_env->device_mgr))),
worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
@@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
Status SessionMgr::CreateSession(const string& session,
const ServerDef& server_def) {
mutex_lock l(mu_);
+ if (session.empty()) {
+ return errors::InvalidArgument("Session must be non-empty.");
+ }
+
const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr;
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- new RpcRendezvousMgr(worker_env_, worker_name, worker_cache));
+ std::vector<Device*> renamed_devices;
+ for (Device* d : worker_env_->local_devices) {
+ renamed_devices.push_back(
+ RenamedDevice::NewRenamedDevice(worker_name, d, false));
+ }
+ std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
std::unique_ptr<GraphMgr> graph_mgr(
- new GraphMgr(worker_env_, rendezvous_mgr.get()));
+ new GraphMgr(worker_env_, device_mgr.get()));
std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
- std::move(rendezvous_mgr), std::move(graph_mgr)));
+ std::move(device_mgr), std::move(graph_mgr)));
sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK();
@@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
if (it != sessions_.end()) {
sessions_.erase(it);
}
- std::set<string> graph_handles;
- for (auto graph_handle_it = sessions_by_graph_handle_.begin();
- graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
- if (graph_handle_it->second == session) {
- graph_handles.insert(graph_handle_it->first);
- graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
- if (graph_handle_it == sessions_by_graph_handle_.end()) break;
- }
- }
- for (auto step_id_it = graphs_by_step_id_.begin();
- step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
- if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
- step_id_it = graphs_by_step_id_.erase(step_id_it);
- if (step_id_it == graphs_by_step_id_.end()) break;
- }
- }
return Status::OK();
}
@@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
-WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
- const string& graph_handle) {
- auto it = sessions_by_graph_handle_.find(graph_handle);
- if (it == sessions_by_graph_handle_.end()) {
- return &legacy_session_;
- } else {
- return WorkerSessionForSessionUnlocked(it->second);
- }
-}
-
-WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
- const string& graph_handle) {
- mutex_lock l(mu_);
- return WorkerSessionForGraphHandleUnlocked(graph_handle);
-}
-
-WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
- mutex_lock l(mu_);
- auto it = graphs_by_step_id_.find(step_id);
- if (it == graphs_by_step_id_.end()) {
- return &legacy_session_;
- } else {
- return WorkerSessionForGraphHandleUnlocked(it->second);
- }
-}
-
-void SessionMgr::AssociateGraphWithSession(const string& session,
- const string& graph_handle) {
- mutex_lock l(mu_);
- sessions_by_graph_handle_[graph_handle] = session;
-}
-
-void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
- mutex_lock l(mu_);
- auto it = sessions_by_graph_handle_.find(graph_handle);
- if (it != sessions_by_graph_handle_.end()) {
- sessions_by_graph_handle_.erase(it);
- }
-}
-
-void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
- const int64 step_id) {
- mutex_lock l(mu_);
- graphs_by_step_id_[step_id] = graph_handle;
-}
-
-void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
- mutex_lock l(mu_);
- auto it = graphs_by_step_id_.find(step_id);
- if (it != graphs_by_step_id_.end()) {
- graphs_by_step_id_.erase(it);
- }
-}
-
} // namespace tensorflow
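The per-session device table built in CreateSession above can be summarized as: wrap, don't copy. Each local device is re-homed under the session's worker name via RenamedDevice, and the session's DeviceMgr owns only the wrappers. A hedged sketch (the third NewRenamedDevice argument is assumed to mean owns-underlying, per the call site above):

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/renamed_device.h"

// Sketch: build a session-private DeviceMgr over renamed views of the
// worker's local devices; the underlying devices stay owned elsewhere.
std::unique_ptr<tensorflow::DeviceMgr> MakeSessionDeviceMgr(
    const std::string& worker_name,
    const std::vector<tensorflow::Device*>& local_devices) {
  std::vector<tensorflow::Device*> renamed;
  for (tensorflow::Device* d : local_devices) {
    renamed.push_back(tensorflow::RenamedDevice::NewRenamedDevice(
        worker_name, d, /*owns_underlying=*/false));
  }
  return std::unique_ptr<tensorflow::DeviceMgr>(
      new tensorflow::DeviceMgr(renamed));
}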
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index 455b5c8d9d..c44bca7b7a 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -30,6 +30,8 @@ struct WorkerEnv;
// SessionMgr keeps track of information related to a given session.
//
+// SessionMgr runs on the workers.
+//
// SessionMgr is threadsafe.
class SessionMgr {
public:
@@ -39,7 +41,6 @@ class SessionMgr {
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
~SessionMgr() {}
@@ -50,49 +51,36 @@ class SessionMgr {
WorkerSession* WorkerSessionForSession(const string& session);
WorkerSession* LegacySession();
- // Locates the worker session for a given graph handle
- WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
- void AssociateGraphWithSession(const string& session,
- const string& graph_handle);
- void DisassociateGraphFromSession(const string& graph_handle);
-
- // Locates a worker session for a given step id
- WorkerSession* WorkerSessionForStepId(const int64 step_id);
- void AssociateStepIdWithGraph(const string& graph_handle,
- const int64 step_id);
- void DisassociateStepIdFromGraph(const int64 step_id);
-
Status DeleteSession(const string& session);
static string WorkerNameFromServerDef(const ServerDef& server_def);
private:
- // Private constructor to work around std::unique_ptr ownership issues.
- explicit SessionMgr(
- WorkerEnv* worker_env, const string& default_worker_name,
- std::unique_ptr<WorkerCacheInterface> default_worker_cache,
- RendezvousMgrInterface* default_rendezvous_mgr,
- WorkerCacheFactory worker_cache_factory);
-
const WorkerEnv* const worker_env_; // Not owned.
+
+ // A note about destruction:
+ // We must delete graph_mgr before device_mgr, due to shared
+ // ownership of OpKernels in the executors. (The graph_mgr will
+ // free all stateless OpKernels, and pass over borrowed stateful
+ // OpKernels, which are also held in their respective devices'
+ // OpSegments.)
+ //
+ // legacy_session_ owns worker_env_.device_mgr, so we must ensure that the
+ // WorkerSessions in sessions_ (which own RenamedDevices rather than the
+ // underlying devices) are deleted before legacy_session_ is deleted.
+ // Further, each WorkerSession's device_mgr must be deleted after its
+ // graph_mgr.
+
WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_;
WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
- WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// A map from session identifier to internal session structure.
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
-
- // A map from graph handles to the session that they belong to.
- std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
-
- // A map from globally-unique step id's to the corresponding graph handles.
- std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index d3f3fa8395..7132f123a5 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(),
- std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
- &env_, "/job:mnist/replica:0/task:0", nullptr)),
factory_),
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
@@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << " was null";
-
- TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
-}
-
-TEST_F(SessionMgrTest, AssociateGraphWithSession) {
- ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
- WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(session, graph_session);
-
+ EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
-TEST_F(SessionMgrTest, AssociateStepWithGraph) {
+TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
+ string session_handle = "";
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(session, graph_session);
-
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(session, step_session);
- ASSERT_EQ(graph_session, step_session);
+ EXPECT_EQ(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
-TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(legacy_session_, graph_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
- ServerDef server_def;
- string session_handle = "test_session_handle";
- TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
- WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
- ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
-
- string graph_handle = "test_graph_handle";
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- mgr_.AssociateGraphWithSession(session_handle, graph_handle);
- WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
- ASSERT_EQ(legacy_session_, graph_session);
-
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
-TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
- string session_handle = "test_session_handle";
- string graph_handle = "test_graph_handle";
- int64 step_id = 1234567890L;
- mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
- WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
- ASSERT_EQ(legacy_session_, step_session);
-}
-
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def;
server_def.set_job_name("worker");
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 89639e21b5..07bb17981d 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(),
request->debug_options(), response->mutable_graph_handle());
- if (s.ok()) {
- env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
- response->graph_handle());
- }
done(s);
}
@@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) {
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Deregister(request->graph_handle());
- env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
done(s);
}
@@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
}
void Worker::AbortStep(int64 step_id) {
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
- Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
+ Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this
@@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
- env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
}
CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync(
- request->graph_handle(), step_id, request->exec_opts(), collector,
- cost_graph, cm, in,
+ request->graph_handle(), step_id, session, request->exec_opts(),
+ collector, cost_graph, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) {
if (s.ok()) {
@@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
WorkerSession* session =
- env_->session_mgr->WorkerSessionForGraphHandle(graph_handle);
- env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
+ env_->session_mgr->WorkerSessionForSession(request->session_handle());
+
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
[cm]() { cm->StartCancel(); });
}
session->graph_mgr->ExecuteAsync(
- graph_handle, step_id, request->exec_opts(), nullptr /* collector */,
- nullptr /* cost_graph */, cm, in,
+ graph_handle, step_id, session, request->exec_opts(),
+ nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, token, graph_handle, step_id, cm](Status s) {
{
mutex_lock l(mu_);
@@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) {
const int64 step_id = request->step_id();
- WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
- session->rendezvous_mgr->Cleanup(step_id);
+ env_->rendezvous_mgr->Cleanup(step_id);
done(Status::OK());
}
@@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) {
// Figures out which device the tensor is hosted on.
- TF_RETURN_IF_ERROR(
- env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
+ string local_name = DeviceNameUtils::LocalName(parsed.src_device);
+ TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
// Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
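With ClusterSpec propagation, the job/replica/task prefix in a rendezvous key comes from the client-supplied ClusterDef rather than from the worker's own server_def, so PrepareRecvTensor above normalizes the source device to its local name before the device_mgr lookup. A minimal Python sketch of that normalization, assuming (per the 2017-era convention) that local names carry a fixed /job:localhost/replica:0/task:0 prefix; this is an illustration, not the DeviceNameUtils implementation:

    import re

    def local_name(fullname):
        # Keep only the trailing "<type>:<id>" of a fully qualified device name
        # and re-prefix it with the fixed local identity (assumed convention).
        m = re.search(r'/([A-Za-z_]+):([0-9]+)$', fullname)
        if m is None:
            raise ValueError('not a fully qualified device name: %r' % fullname)
        return '/job:localhost/replica:0/task:0/%s:%s' % (m.group(1), m.group(2))

    # local_name('/job:worker/replica:0/task:7/cpu:0')
    # -> '/job:localhost/replica:0/task:0/cpu:0'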
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index 24fb5948a7..f09bea328f 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
+#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -24,8 +25,10 @@ namespace thread {
class ThreadPool;
} // namespace thread
+class Device;
class DeviceMgr;
class Env;
+class RendezvousMgrInterface;
class SessionMgr;
// The worker environment class, which holds a bag of pointers to
@@ -38,10 +41,18 @@ struct WorkerEnv {
// session_mgr encapsulates state for each session.
SessionMgr* session_mgr = nullptr;
+ // The local devices of this worker. Devices are owned by the device_mgr.
+ //
+ // REQUIRES: !local_devices.empty().
+ std::vector<Device*> local_devices;
+
// device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr;
+ // A set of rendezvous keyed by step ids.
+ RendezvousMgrInterface* rendezvous_mgr = nullptr;
+
// A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr;
};
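Moving the rendezvous manager from WorkerSession into WorkerEnv gives handlers that only know a step id (AbortStep and CleanupGraphAsync above) a session-independent place to find it. A toy Python sketch of a "set of rendezvous keyed by step ids" (names here are illustrative, not the RendezvousMgrInterface API):

    import threading

    class StepRendezvousMgr(object):
        # One rendezvous object per step id, created on first use.
        def __init__(self):
            self._lock = threading.Lock()
            self._table = {}  # step_id -> per-step rendezvous object

        def find(self, step_id):
            with self._lock:
                return self._table.setdefault(step_id, {'step_id': step_id})

        def cleanup(self, step_id):
            with self._lock:
                self._table.pop(step_id, None)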
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index 508bc7f468..c9db28ec67 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -113,6 +113,11 @@ class WorkerInterface {
return CallAndWait(&ME::GetStatusAsync, request, response);
}
+ Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
+ CreateWorkerSessionResponse* response) {
+ return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
+ }
+
Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response);
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 8298e16959..8691450e9b 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -17,14 +17,84 @@ limitations under the License.
namespace tensorflow {
-WorkerSession::WorkerSession(
- const string& worker_name,
- std::unique_ptr<WorkerCacheInterface> worker_cache,
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr,
- std::unique_ptr<GraphMgr> graph_mgr)
+namespace {
+
+// A private cache that wraps worker_cache and allows reuse of
+// WorkerInterface objects.
+class WorkerFreeListCache : public WorkerCacheInterface {
+ public:
+ explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
+ : wrapped_(std::move(w)) {}
+
+ ~WorkerFreeListCache() final {
+ for (auto p : workers_) {
+ wrapped_->ReleaseWorker(p.first, p.second.worker);
+ }
+ }
+
+ void ListWorkers(std::vector<string>* workers) const override {
+ wrapped_->ListWorkers(workers);
+ }
+
+ WorkerInterface* CreateWorker(const string& target) override {
+ mutex_lock l(mu_);
+ auto p = workers_.find(target);
+ if (p != workers_.end()) {
+ return p->second.worker;
+ }
+ WorkerState state;
+ state.worker = wrapped_->CreateWorker(target);
+ if (state.worker != nullptr) {
+ workers_.insert(std::make_pair(target, state));
+ }
+ return state.worker;
+ }
+
+ void ReleaseWorker(const string& target, WorkerInterface* worker) override {
+ // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
+ }
+
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ wrapped_->GetDeviceLocalityAsync(device, locality, done);
+ }
+
+ void SetLogging(bool active) override { wrapped_->SetLogging(active); }
+
+ void ClearLogs() override { wrapped_->ClearLogs(); }
+
+ bool RetrieveLogs(int64 step_id, StepStats* ss) override {
+ return wrapped_->RetrieveLogs(step_id, ss);
+ }
+
+ private:
+ std::unique_ptr<WorkerCacheInterface> wrapped_;
+
+ // Information kept per created WorkerInterface.
+ struct WorkerState {
+ WorkerInterface* worker;
+ // TODO(jeff,sanjay): Add reference count if we support eviction.
+ };
+
+ // TODO(jeff,sanjay): Eviction when the map becomes too big.
+ mutex mu_;
+ std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
+};
+
+} // namespace
+
+WorkerSession::WorkerSession(const string& worker_name,
+ std::unique_ptr<WorkerCacheInterface> worker_cache,
+ std::unique_ptr<DeviceMgr> device_mgr,
+ std::unique_ptr<GraphMgr> graph_mgr)
: worker_name(worker_name),
- worker_cache(std::move(worker_cache)),
- rendezvous_mgr(std::move(rendezvous_mgr)),
+ worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
+ device_mgr(std::move(device_mgr)),
graph_mgr(std::move(graph_mgr)) {}
} // namespace tensorflow
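The WorkerFreeListCache above exists so each per-session worker cache hands out one long-lived WorkerInterface per target instead of re-creating stubs. The caching pattern itself, restated as a short Python sketch (create_fn stands in for the wrapped cache; this shows the idea, not the C++ class):

    class FreeListCache(object):
        def __init__(self, create_fn):
            self._create_fn = create_fn  # target -> stub, e.g. a gRPC channel
            self._stubs = {}

        def create_worker(self, target):
            stub = self._stubs.get(target)
            if stub is None:
                stub = self._create_fn(target)
                if stub is not None:  # cache only successful creations
                    self._stubs[target] = stub
            return stub

        def release_worker(self, target, stub):
            # Intentionally a no-op; eviction would decrement a ref-count here.
            pass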
diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h
index e6ebe88329..77cf4de8f7 100644
--- a/tensorflow/core/distributed_runtime/worker_session.h
+++ b/tensorflow/core/distributed_runtime/worker_session.h
@@ -18,14 +18,13 @@ limitations under the License.
#include <string>
+#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
-#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
namespace tensorflow {
class GraphMgr;
-class RendezvousMgrInterface;
class WorkerCacheInterface;
// WorkerSession encapsulates all of the state relating to a given session.
@@ -36,17 +35,20 @@ struct WorkerSession {
// Object from which WorkerInterface instances can be obtained.
const std::unique_ptr<WorkerCacheInterface> worker_cache;
- // A set of rendezvous keyed by step ids.
- const std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr;
+ // Collection of local devices. These are typically RenamedDevices in all
+ // sessions except SessionMgr.legacy_session_, whose device_mgr aliases
+ // worker_env_.device_mgr and therefore holds the true devices.
+ const std::unique_ptr<DeviceMgr> device_mgr;
// graph_mgr keeps track of the registered graphs of this session.
//
// Note: graph_mgr must be deleted before rendezvous_mgr!
+ // Note: graph_mgr must be deleted before device_mgr!
const std::unique_ptr<GraphMgr> graph_mgr;
WorkerSession(const string& worker_name,
std::unique_ptr<WorkerCacheInterface> worker_cache,
- std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr,
+ std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<GraphMgr> graph_mgr);
};
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 8894671fdf..27fe28fe60 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -115,7 +115,7 @@ class DeviceBase {
cpu_worker_threads_ = t;
}
- const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
+ virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
CHECK(cpu_worker_threads_ != nullptr);
return cpu_worker_threads_;
}
@@ -140,7 +140,7 @@ class DeviceBase {
gpu_device_info_ = g;
}
- const GpuDeviceInfo* tensorflow_gpu_device_info() const {
+ virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const {
return gpu_device_info_;
}
@@ -170,13 +170,13 @@ class DeviceBase {
return GetAllocator(attr);
}
- const Eigen::ThreadPoolDevice* eigen_cpu_device() {
+ virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
}
#ifdef TENSORFLOW_USE_SYCL
- const Eigen::SyclDevice* eigen_sycl_device() const {
+ virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr);
return eigen_sycl_device_;
}
diff --git a/tensorflow/core/protobuf/cluster.proto b/tensorflow/core/protobuf/cluster.proto
new file mode 100644
index 0000000000..33c87eefe0
--- /dev/null
+++ b/tensorflow/core/protobuf/cluster.proto
@@ -0,0 +1,82 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "ClusterProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+// This file contains protos to be used when defining a TensorFlow
+// cluster.
+//
+// EXAMPLES
+// --------
+//
+// 1. A single-process cluster, containing "/job:local/task:0".
+//
+// Cluster:
+// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
+//
+// Server:
+// cluster { $CLUSTER } job_name: 'local' task_index: 0
+//
+// 2. A two-process cluster, containing "/job:local/task:{0,1}".
+//
+// Cluster:
+// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
+// tasks { key: 1 value: 'localhost:2223' } }
+//
+// Servers:
+// cluster { $CLUSTER } job_name: 'local' task_index: 0
+// cluster { $CLUSTER } job_name: 'local' task_index: 1
+//
+// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
+// "/job:ps/task:{0,1}".
+//
+// Cluster:
+// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
+// tasks { key: 1 value: 'worker2:2222' }
+// tasks { key: 2 value: 'worker3:2222' } }
+// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+// tasks { key: 1 value: 'ps1:2222' } }
+//
+// Servers:
+// cluster { $CLUSTER } job_name: 'worker' task_index: 0
+// cluster { $CLUSTER } job_name: 'worker' task_index: 1
+// cluster { $CLUSTER } job_name: 'worker' task_index: 2
+// cluster { $CLUSTER } job_name: 'ps' task_index: 0
+// cluster { $CLUSTER } job_name: 'ps' task_index: 1
+
+// Defines a single job in a TensorFlow cluster.
+message JobDef {
+ // The name of this job.
+ string name = 1;
+
+ // Mapping from task ID to "hostname:port" string.
+ //
+ // If the `name` field contains "worker", and the `tasks` map contains a
+ // mapping from 7 to "example.org:2222", then the device prefix
+ // "/job:worker/task:7" will be assigned to "example.org:2222".
+ map<int32, string> tasks = 2;
+}
+
+// Defines a TensorFlow cluster as a set of jobs.
+message ClusterDef {
+ // The jobs that comprise the cluster.
+ repeated JobDef job = 1;
+}
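As a concrete usage of these messages, the two-job cluster from example 3 can be assembled with the generated cluster_pb2 module, the same API exercised by the new session_test.py tests later in this change:

    from tensorflow.core.protobuf import cluster_pb2

    cluster_def = cluster_pb2.ClusterDef()
    worker = cluster_def.job.add()
    worker.name = 'worker'
    worker.tasks[0] = 'worker1:2222'
    worker.tasks[1] = 'worker2:2222'
    worker.tasks[2] = 'worker3:2222'
    ps = cluster_def.job.add()
    ps.name = 'ps'
    ps.tasks[0] = 'ps0:2222'
    ps.tasks[1] = 'ps1:2222'

A client then attaches this to a session via config_pb2.ConfigProto(cluster_def=cluster_def), which is how the master learns the cluster at session-creation time.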
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 5c0f7232eb..630f47633f 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/protobuf/debug.proto";
+import "tensorflow/core/protobuf/cluster.proto";
import "tensorflow/core/protobuf/rewriter_config.proto";
message GPUOptions {
@@ -259,6 +260,11 @@ message ConfigProto {
// Options that apply when this session uses the distributed runtime.
RPCOptions rpc_options = 13;
+
+ // Optional list of all workers to use in this session.
+ ClusterDef cluster_def = 14;
+
+ // Next: 15
};
// Options for a single Run() call.
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index de91b6133e..e607b1c42a 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -38,6 +38,9 @@ message CreateSessionRequest {
// Configuration options.
ConfigProto config = 2;
+
+ // The target string used from the client's perspective.
+ string target = 3;
}
message CreateSessionResponse {
diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto
index c4077bd98e..6199e707e5 100644
--- a/tensorflow/core/protobuf/tensorflow_server.proto
+++ b/tensorflow/core/protobuf/tensorflow_server.proto
@@ -16,6 +16,7 @@ limitations under the License.
syntax = "proto3";
import "tensorflow/core/protobuf/config.proto";
+import "tensorflow/core/protobuf/cluster.proto";
package tensorflow;
option cc_enable_arenas = true;
@@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
-// This file contains protos to be used when defining a TensorFlow
-// cluster, and a server within that cluster.
-//
-// EXAMPLES
-// --------
-//
-// 1. A single-process cluster, containing "/job:local/task:0".
-//
-// Cluster:
-// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
-//
-// Server:
-// cluster { $CLUSTER } job_name: 'local' task_index: 0
-//
-// 2. A two-process cluster, containing "/job:local/task:{0,1}".
-//
-// Cluster:
-// job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
-// tasks { key: 1 value: 'localhost:2223' } }
-//
-// Servers:
-// cluster { $CLUSTER } job_name: 'local' task_index: 0
-// cluster { $CLUSTER } job_name: 'local' task_index: 1
-//
-// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
-// "/job:ps/task:{0,1}".
-//
-// Cluster:
-// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
-// tasks { key: 1 value: 'worker2:2222' }
-// tasks { key: 2 value: 'worker3:2222' } }
-// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
-// tasks { key: 1 value: 'ps1:2222' } }
-//
-// Servers:
-// cluster { $CLUSTER } job_name: 'worker' task_index: 0
-// cluster { $CLUSTER } job_name: 'worker' task_index: 1
-// cluster { $CLUSTER } job_name: 'worker' task_index: 2
-// cluster { $CLUSTER } job_name: 'ps' task_index: 0
-// cluster { $CLUSTER } job_name: 'ps' task_index: 1
-
-// Defines a single job in a TensorFlow cluster.
-message JobDef {
- // The name of this job.
- string name = 1;
-
- // Mapping from task ID to "hostname:port" string.
- //
- // If the `name` field contains "worker", and the `tasks` map contains a
- // mapping from 7 to "example.org:2222", then the device prefix
- // "/job:worker/task:7" will be assigned to "example.org:2222".
- //
- // NOTE(mrry): Currently, only a dense task ID space starting at 0 is
- // supported.
- map<int32, string> tasks = 2;
-}
-
-// Defines a TensorFlow cluster as a set of jobs.
-message ClusterDef {
- // The jobs that comprise the cluster.
- repeated JobDef job = 1;
-}
-
// Defines the configuration of a single TensorFlow server.
message ServerDef {
// The cluster of which this server is a member.
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 661327847c..cf05aece39 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -119,6 +119,10 @@ message RegisterGraphResponse {
////////////////////////////////////////////////////////////////////////////////
message DeregisterGraphRequest {
+ // The session_handle used when registering the graph. If session_handle is
+ // empty, a single global namespace is used.
+ string session_handle = 2;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
@@ -167,6 +171,12 @@ message ExecutorOpts {
};
message RunGraphRequest {
+ // session_handle is the master-generated unique id for this session.
+ // If session_handle is non-empty, it must be the same as the one used when
+ // registering the graph. If it is empty, a single global namespace is used
+ // to search for the graph_handle.
+ string session_handle = 8;
+
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
@@ -193,6 +203,8 @@ message RunGraphRequest {
bool is_partial = 6;
// True if this is the last partial run request in a sequence of requests.
bool is_last_partial_run = 7;
+
+ // Next: 9
}
message RunGraphResponse {
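The session_handle fields added to DeregisterGraphRequest and RunGraphRequest follow one rule: an empty handle selects the single legacy global namespace, while a non-empty handle selects that session's private namespace. A toy Python sketch of the resolution rule (illustrative only, not the SessionMgr implementation):

    class GraphNamespace(object):
        def __init__(self):
            self._legacy = {}    # graph_handle -> graph (empty session_handle)
            self._sessions = {}  # session_handle -> {graph_handle -> graph}

        def register(self, session_handle, graph_handle, graph):
            ns = (self._legacy if not session_handle
                  else self._sessions.setdefault(session_handle, {}))
            ns[graph_handle] = graph

        def lookup(self, session_handle, graph_handle):
            if not session_handle:  # empty -> single global namespace
                return self._legacy[graph_handle]
            return self._sessions[session_handle][graph_handle]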
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 864a96ef34..6336ca2310 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.config_pb2 import *
+from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.protobuf.rewriter_config_pb2 import *
from tensorflow.core.util.event_pb2 import *
@@ -131,6 +132,7 @@ _allowed_symbols = [
'AttrValue',
'AutoParallelOptions',
'ConfigProto',
+ 'ClusterDef',
'DeviceSpec',
'Event',
'GPUOptions',
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 9add5bd3cd..040cc33315 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -29,6 +29,7 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
@@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with CaptureStderr() as log:
sess.run(c)
# Ensure that we did log device placement.
- self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log))
+ self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log))
def testLocalMasterSessionTimeout(self):
# Test that the timeout passed in a config to the session works correctly.
@@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase):
server = server_lib.Server.create_local_server()
self.runTestBuildGraphError(session.Session(server.target))
+ def testClusterSpecPropagationSimple(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config)
+ output = sess.run(const)
+ self.assertEqual(17, output)
+
+ def testClusterSpecPropagationWorker2Placement(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config, graph=g)
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+ output = sess.run(const, options=run_options, run_metadata=run_metadata)
+ self.assertEqual(17, output)
+ self.assertEqual(1,
+ len([
+ node_stats
+ for dev_stats in run_metadata.step_stats.dev_stats
+ for node_stats in dev_stats.node_stats
+ if '/job:worker/replica:0/task:1/device:CPU:0' ==
+ dev_stats.device and 'Const' == node_stats.node_name
+ ]))
+
+ def testClusterSpecPropagationWorker1Placement(self):
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
+ const = constant_op.constant(17)
+ sess = session.Session(server1.target, config=config, graph=g)
+ output = sess.run(const)
+ self.assertEqual(17, output)
+
+ def testClusterSpecPropagationThreeServers2Graphs(self):
+ """Boots 3 servers, creates 2 sessions, and verifies their independence.
+
+ We create 2 clusterspecs:
+ 1. server2 as the master, server1 as a worker
+ 2. server2 as the master, server3 as a worker
+
+ We ensure that variables on the workers are independent.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def1 = cluster_pb2.ClusterDef()
+ job1 = cluster_def1.job.add()
+ job1.name = 'worker1'
+ job1.tasks[0] = server2.target[len('grpc://'):]
+ job1.tasks[1] = server1.target[len('grpc://'):]
+
+ cluster_def2 = cluster_pb2.ClusterDef()
+ job2 = cluster_def2.job.add()
+ job2.name = 'worker2'
+ job2.tasks[0] = server2.target[len('grpc://'):]
+ job2.tasks[1] = server3.target[len('grpc://'):]
+
+ config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
+ config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
+
+ with ops.Graph().as_default() as g1:
+ with ops.device('/job:worker1/task:1'):
+ var1 = variables.Variable(array_ops.zeros([2]), name='var1')
+ update_op1 = state_ops.assign_add(
+ var1, array_ops.ones([2]), name='var1_assign_add')
+ init1 = variables.global_variables_initializer()
+
+ with ops.Graph().as_default() as g2:
+ with ops.device('/job:worker2/task:1'):
+ var2 = variables.Variable(array_ops.zeros([2]), name='var2')
+ update_op2 = state_ops.assign_add(
+ var2, array_ops.ones([2]), name='var2_assign_add')
+ init2 = variables.global_variables_initializer()
+
+ sess1 = session.Session(server2.target, graph=g1, config=config1)
+ sess2 = session.Session(server2.target, graph=g2, config=config2)
+
+ init1.run(session=sess1)
+ init2.run(session=sess2)
+
+ expected_zeros = np.zeros([2])
+ expected_ones = np.ones([2])
+
+ self.assertAllEqual(expected_zeros, sess1.run(var1))
+ self.assertAllEqual(expected_zeros, sess2.run(var2))
+
+ self.assertAllEqual(expected_ones, sess1.run(update_op1))
+ self.assertAllEqual(expected_ones, sess1.run(var1))
+ self.assertAllEqual(expected_zeros, sess2.run(var2))
+ self.assertAllEqual(expected_ones, sess2.run(update_op2))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
+ self.assertAllEqual(expected_ones, sess2.run(var2))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
+
+ def testClusterSpecPropagationThreeServers(self):
+ """Boots 3 servers, creates 2 sessions, and verifies their independence.
+
+ We create 2 clusterspecs:
+ 1. server2 as the master, server1 as a worker
+ 2. server2 as the master, server3 as a worker
+
+ We ensure that variables on the workers are independent.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def1 = cluster_pb2.ClusterDef()
+ job1 = cluster_def1.job.add()
+ job1.name = 'worker'
+ job1.tasks[0] = server2.target[len('grpc://'):]
+ job1.tasks[1] = server1.target[len('grpc://'):]
+
+ cluster_def2 = cluster_pb2.ClusterDef()
+ job2 = cluster_def2.job.add()
+ job2.name = 'worker'
+ job2.tasks[0] = server2.target[len('grpc://'):]
+ job2.tasks[1] = server3.target[len('grpc://'):]
+
+ config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
+ config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)
+
+ with ops.device('/job:worker/task:1'):
+ var = variables.Variable(array_ops.zeros([2]), name='var')
+ feed = array_ops.placeholder(dtypes.float32, shape=(2))
+ update_op = var.assign_add(feed)
+
+ sess1 = session.Session(server2.target, config=config1)
+ sess2 = session.Session(server2.target, config=config2)
+
+ variables.global_variables_initializer().run(session=sess1)
+ variables.global_variables_initializer().run(session=sess2)
+
+ expected_zeros = np.zeros([2])
+ expected_ones = np.ones([2])
+
+ self.assertAllEqual(expected_zeros, sess1.run(var))
+ self.assertAllEqual(expected_zeros, sess2.run(var))
+ self.assertAllEqual(expected_ones,
+ sess1.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones, sess1.run(var))
+ self.assertAllEqual(expected_zeros, sess2.run(var))
+ self.assertAllEqual(expected_ones,
+ sess2.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones + expected_ones,
+ sess1.run(update_op, feed_dict={feed: expected_ones}))
+ self.assertAllEqual(expected_ones, sess2.run(var))
+ self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))
+
+ def testClusterSpecPropagationThreeServersOneCluster(self):
+ """Boots 3 servers and verifies cross-worker communication in one session.
+
+ Additionally, in this cluster, we ensure the master is not the 0-th worker.
+
+ Note: this test only uses one session.
+ """
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+ server3 = server_lib.Server.create_local_server()
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server3.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ job.tasks[2] = server1.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ # Add ops to the devices in non-linear order.
+
+ with ops.device('/job:worker/task:1'):
+ feed1 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const1 = constant_op.constant(2.0)
+ mul1 = const1 * feed1
+
+ with ops.device('/job:worker/task:2'):
+ feed2 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const2 = constant_op.constant(2.0)
+ mul2 = const2 * feed2
+
+ with ops.device('/job:worker/task:0'):
+ feed0 = array_ops.placeholder(dtypes.float32, shape=(2))
+ const0 = constant_op.constant(2.0)
+ mul0 = const0 * feed0
+
+ sum_op = mul0 + mul1 + mul2
+
+ ones = np.ones([2])
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ # Run!
+ with session.Session(server1.target, config=config) as sess:
+ output = sess.run(
+ sum_op,
+ options=run_options,
+ run_metadata=run_metadata,
+ feed_dict={feed1: ones,
+ feed2: ones,
+ feed0: ones})
+ self.assertAllEqual(6 * ones, output)
+
+ self.assertEqual(
+ 3,
+ len([
+ dev_stats.device
+ for dev_stats in run_metadata.step_stats.dev_stats
+ for node_stats in dev_stats.node_stats
+ if '/job:worker/replica:0/task:' in dev_stats.device and
+ node_stats.node_name.startswith('Const')
+ ]), run_metadata)
+
+ def testClusterSpecPropagationPartialRun(self):
+ """Test successful partial run with ClusterSpec propagation."""
+ server1 = server_lib.Server.create_local_server()
+ server2 = server_lib.Server.create_local_server()
+
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = 'worker'
+ job.tasks[0] = server1.target[len('grpc://'):]
+ job.tasks[1] = server2.target[len('grpc://'):]
+ config = config_pb2.ConfigProto(cluster_def=cluster_def)
+
+ with ops.device('/job:worker/task:0'):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ with ops.device('/job:worker/task:1'):
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ with ops.device('/job:worker/task:0'):
+ r2 = math_ops.multiply(r1, c)
+
+ with session.Session(server1.target, config=config) as sess:
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ res = sess.partial_run(h, r2, feed_dict={c: 3})
+ self.assertEqual(9, res)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index d2ccf37d88..2091eca0b9 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
@@ -276,14 +277,14 @@ class ClusterSpec(object):
"from integers to strings." % job_name)
self._cluster_spec[job_name] = job_tasks
self._make_cluster_def()
- elif isinstance(cluster, tensorflow_server_pb2.ClusterDef):
+ elif isinstance(cluster, cluster_pb2.ClusterDef):
self._cluster_def = cluster
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
elif isinstance(cluster, ClusterSpec):
- self._cluster_def = tensorflow_server_pb2.ClusterDef()
+ self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {}
for job_def in self._cluster_def.job:
@@ -440,7 +441,7 @@ class ClusterSpec(object):
TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
of strings.
"""
- self._cluster_def = tensorflow_server_pb2.ClusterDef()
+ self._cluster_def = cluster_pb2.ClusterDef()
# NOTE(mrry): Sort by job_name to produce deterministic protobufs.
for job_name, tasks in sorted(self._cluster_spec.items()):
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index bdf3d9c017..f4ac3c9758 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import *
# pylint: enable=wildcard-import
# Distributed computing support.
-from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef
-from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.cluster_pb2 import JobDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.server_lib import Server
@@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server
_allowed_symbols = [
# TODO(cwhipkey): review these and move to contrib or expose through
# documentation.
- "generate_checkpoint_state_proto", # Used internally by saver.
+ "generate_checkpoint_state_proto", # Used internally by saver.
"checkpoint_exists", # Only used in test?
"get_checkpoint_mtimes", # Only used in test?
# Legacy: remove.
"do_quantize_training_on_graphdef", # At least use graph_def, not graphdef.
- # No uses within tensorflow.
+ # No uses within tensorflow.
"queue_runner", # Use tf.train.start_queue_runner etc directly.
- # This is also imported internally.
+ # This is also imported internally.
# TODO(drpng): document these. The reference in howtos/distributed does
# not link.
"SyncReplicasOptimizer",
# Protobufs:
- "BytesList", # from example_pb2.
+ "BytesList", # from example_pb2.
"ClusterDef",
- "Example", # from example_pb2
- "Feature", # from example_pb2
- "Features", # from example_pb2
- "FeatureList", # from example_pb2
- "FeatureLists", # from example_pb2
- "FloatList", # from example_pb2.
- "Int64List", # from example_pb2.
+ "Example", # from example_pb2
+ "Feature", # from example_pb2
+ "Features", # from example_pb2
+ "FeatureList", # from example_pb2
+ "FeatureLists", # from example_pb2
+ "FloatList", # from example_pb2.
+ "Int64List", # from example_pb2.
"JobDef",
- "SaverDef", # From saver_pb2.
- "SequenceExample", # from example_pb2.
+ "SaverDef", # From saver_pb2.
+ "SequenceExample", # from example_pb2.
"ServerDef",
]
# Include extra modules for docstrings because:
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
index 805a9bdd4f..da6af3919e 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
@@ -7,6 +7,10 @@ tf_class {
mtype: "<type \'int\'>"
}
member {
+ name: "CLUSTER_DEF_FIELD_NUMBER"
+ mtype: "<type \'int\'>"
+ }
+ member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
index feb73bd7d4..93ff856b09 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.ClusterDef"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.ClusterDef\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.ClusterDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
index 2d7fcbe545..ac6d81541a 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef.TasksEntry"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.TasksEntry\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.TasksEntry\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
index fc5b76341d..ce34537fa1 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.train.JobDef"
tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.JobDef\'>"
+ is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.JobDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"