diff options
author | Derek Murray <mrry@google.com> | 2016-08-27 08:01:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-27 09:17:43 -0700 |
commit | 8177edd7700ccbe0c831e680a1acb5275819d762 (patch) | |
tree | 895d81300a231d278c990ff7c08c9c8d1133f1c4 /tensorflow/core/distributed_runtime/rpc/grpc_channel.h | |
parent | f409f4f6ac23c182053595db7a2f0f490002005c (diff) |
Support sparse jobs for TensorFlow gRPC servers.
A TensorFlow server (tf.train.Server) is configured with a list of
jobs, where each job includes the addresses of the tasks in that
job. At present, the tasks are provided as a dense list, and a server
must be configured with the addresses of all tasks in every job, even
when that server might never contact a particular task.
This CL adds support for configuring individual jobs with a sparse
mapping from task index to network address. The net effect is that a
server in (e.g.) a worker job need not know the addresses of the other
worker tasks. This reduces the amount of configuration needed in two
ways: (i) the cluster specification for an individual server contains
only the server with which it makes contact, and (ii) there is no need
to specify a device filter to prevent the server pinging all known
tasks on session creation (which can lead to unavailability when
unrelated tasks fail).
This CL also cleans up the code in grpc_channel.{cc,h} in three ways:
1. Move unnecessarily public methods into an anonymous namespace.
2. Shorten some of the unwieldy function and class names.
3. Use std::move() where appropriate to avoid copying vectors and maps
of strings.
Change: 131490850
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_channel.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_channel.h | 30 |
1 files changed, 10 insertions, 20 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h index 32e6806ad0..0c0779440e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ #define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#include <map> #include <memory> #include <set> #include <string> @@ -35,14 +36,17 @@ namespace tensorflow { class GrpcChannelSpec { public: struct HostPortsJob { - string job_id; - std::vector<string> host_ports; - int tasks_per_replica; + HostPortsJob(const string& job_id, const std::map<int, string>& host_ports) + : job_id(job_id), host_ports(host_ports) {} + const string job_id; + const std::map<int, string> host_ports; }; Status AddHostPortsJob(const string& job_id, - const std::vector<string>& host_ports, - int tasks_per_replica); + const std::vector<string>& host_ports); + + Status AddHostPortsJob(const string& job_id, + const std::map<int, string>& host_ports); const std::vector<HostPortsJob>& host_ports_jobs() const { return host_ports_jobs_; @@ -75,27 +79,13 @@ class GrpcChannelCache { typedef std::function<SharedGrpcChannelPtr(string)> ChannelCreationFunction; -GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& p, +GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& channel_spec, ChannelCreationFunction channel_func); // Below here are internal-only functions. SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target); -// Returns a ChannelCache that uses a set of known host:port pairs. E.g., say, -// job_id = 'mnist', 'host_ports' = {"h0:0", "h1:1", ..., "h11:11", "h12:12"}, -// tasks_per_replica = 8, /job:mnist/replica:1/task:3 is mapped to host:port -// "h11:11" (11 = 8 * 1 + 3). -// -// The caller takes ownership of the returned object. -GrpcChannelCache* NewHostPortsGrpcChannelCache( - const string& job_id, const std::vector<string>& host_ports, - int tasks_per_replica, ChannelCreationFunction channel_func); - -// Returns a ChannelCache that is the union of a number of other ChannelCaches. -GrpcChannelCache* NewMultiGrpcChannelCache( - const std::vector<GrpcChannelCache*>& caches); - } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ |