aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2017-06-07 22:38:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 22:43:36 -0700
commit3a2971bd8e2277bd6a32bd222852952b57b11fc4 (patch)
treeceb9461791e793baf3d9290ffda28bb9721a4cf5 /tensorflow/contrib/cluster_resolver
parentcd5ac40b31afaec237aaee35007f2dc846caf811 (diff)
Adds the base for ClusterResolvers, a new way of communicating with and retrieving cluster information for running distributed TensorFlow.
Implementations of this class would eventually allow users to simply point TensorFlow at a cluster management endpoint, and TensorFlow will automatically retrieve the host names/IPs and port numbers of TensorFlow workers from the cluster management service. PiperOrigin-RevId: 158358761
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD47
-rw-r--r--tensorflow/contrib/cluster_resolver/README.md5
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/__init__.py23
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py171
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py238
5 files changed, 484 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
new file mode 100644
index 0000000000..34cdb2a132
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -0,0 +1,47 @@
+# Description: Operations defined for Cluster Resolvers
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+package(
+ default_visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
+
+py_library(
+ name = "cluster_resolver_py",
+ srcs = [
+ "python/training/__init__.py",
+ "python/training/cluster_resolver.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework",
+ ],
+)
+
+tf_py_test(
+ name = "cluster_resolver_py_test",
+ srcs = ["python/training/cluster_resolver_test.py"],
+ additional_deps = [
+ ":cluster_resolver_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ main = "python/training/cluster_resolver_test.py",
+)
diff --git a/tensorflow/contrib/cluster_resolver/README.md b/tensorflow/contrib/cluster_resolver/README.md
new file mode 100644
index 0000000000..6fe6871eb4
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/README.md
@@ -0,0 +1,5 @@
+# Cluster Resolvers
+
+Cluster Resolvers are a new way of specifying cluster information for distributed execution. Built on top of the existing `ClusterSpec` framework, Cluster Resolvers allow users to simply specify a configuration and a cluster management service, and a `ClusterResolver` will automatically fetch the relevant information from the service and populate `ClusterSpec`s.
+
+`ClusterResolvers` are designed to work well with `MonitoredTrainingSession` and `ClusterSpec` propagation so that distributed training sessions remain robust in the face of node and network failures.
diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
new file mode 100644
index 0000000000..3520467bc6
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library Imports for Cluster Resolvers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
new file mode 100644
index 0000000000..87da24f22d
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
@@ -0,0 +1,171 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.training.server_lib import ClusterSpec
+
+
+class ClusterResolver(object):
+ """Abstract class for all implementations of ClusterResolvers.
+
+ This defines the skeleton for all implementations of ClusterResolvers.
+ ClusterResolvers are a way for TensorFlow to communicate with various cluster
+ management systems (e.g. GCE, AWS, etc...).
+
+ By letting TensorFlow communicate with these systems, we will be able to
+ automatically discover and resolve IP addresses for various TensorFlow
+ workers. This will eventually allow us to automatically recover from
+ underlying machine failures and scale TensorFlow worker clusters up and down.
+ """
+
+ @abc.abstractmethod
+ def cluster_spec(self):
+ """Retrieves the current state of the cluster and returns a ClusterSpec.
+
+ Returns:
+ A ClusterSpec representing the state of the cluster at the moment this
+ function is called.
+
+ Implementors of this function must take care in ensuring that the
+ ClusterSpec returned is up-to-date at the time of calling this function.
+ This usually means retrieving the information from the underlying cluster
+ management system every time this function is invoked and reconstructing
+ a cluster_spec, rather than attempting to cache anything.
+ """
+ raise NotImplementedError(
+ 'cluster_spec is not implemented for {}.'.format(self))
+
+
+class SimpleClusterResolver(ClusterResolver):
+ """Simple implementation of ClusterResolver that accepts a ClusterSpec."""
+
+ def __init__(self, cluster_spec):
+ """Creates a SimpleClusterResolver from a ClusterSpec."""
+ super(SimpleClusterResolver, self).__init__()
+
+ if not isinstance(cluster_spec, ClusterSpec):
+ raise TypeError('cluster_spec must be a ClusterSpec.')
+ self._cluster_spec = cluster_spec
+
+ def cluster_spec(self):
+ """Returns the ClusterSpec passed into the constructor."""
+ return self._cluster_spec
+
+
+class UnionClusterResolver(ClusterResolver):
+ """Performs a union on underlying ClusterResolvers.
+
+ This class performs a union given two or more existing ClusterResolvers. It
+ merges the underlying ClusterResolvers, and returns one unified ClusterSpec
+ when cluster_spec is called. The details of the merge function are
+ documented in the cluster_spec function.
+ """
+
+ def __init__(self, *args):
+ """Initializes a UnionClusterResolver with other ClusterResolvers.
+
+ Args:
+ *args: `ClusterResolver` objects to be unionized.
+
+ Raises:
+ TypeError: If any argument is not a subclass of `ClusterResolver`.
+ """
+ super(UnionClusterResolver, self).__init__()
+
+ for cluster_resolver in args:
+ if not isinstance(cluster_resolver, ClusterResolver):
+ raise TypeError('All arguments must be a sub-class of '
+ '`ClusterResolver.`')
+ self._cluster_resolvers = args
+
+ def cluster_spec(self):
+ """Returns a union of all the ClusterSpecs from the ClusterResolvers.
+
+ Returns:
+ A ClusterSpec containing host information merged from all the underlying
+ ClusterResolvers.
+
+ Raises:
+ KeyError: If there are conflicting keys detected when merging two or
+ more dictionaries, this exception is raised.
+
+ Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
+ same job name, we will merge the list/dict of workers.
+
+ If *all* underlying ClusterSpecs expose the set of workers as lists, we will
+ concatenate the lists of workers, starting with the list of workers from
+ the first ClusterResolver passed into the constructor.
+
+ If *any* of the ClusterSpecs expose the set of workers as a dict, we will
+ treat all the sets of workers as dicts (even if they are returned as lists)
+ and will only merge them into a dict if there are no conflicting keys. If
+ there is a conflicting key, we will raise a `KeyError`.
+ """
+
+ merged_cluster = {}
+
+ # We figure out whether it is all lists for a particular job, or whether
+ # there are dicts inside.
+ for cluster_resolver in self._cluster_resolvers:
+ cluster_spec = cluster_resolver.cluster_spec()
+ cluster_dict = cluster_spec.as_dict()
+
+ for job_name, tasks in cluster_dict.items():
+ if job_name in merged_cluster:
+ # If we see a dict, then we write a dict out regardless.
+ if isinstance(tasks, dict):
+ merged_cluster[job_name] = {}
+ else:
+ # We take whichever type is present.
+ if isinstance(tasks, list):
+ merged_cluster[job_name] = []
+ else:
+ merged_cluster[job_name] = {}
+
+ # We then do the merge as appropriate in merged_cluster[job].
+ for cluster_resolver in self._cluster_resolvers:
+ cluster_spec = cluster_resolver.cluster_spec()
+ cluster_dict = cluster_spec.as_dict()
+
+ for job_name, tasks in cluster_dict.items():
+ if isinstance(merged_cluster[job_name], list):
+ # We all have lists, we can just concatenate and be done.
+ merged_cluster[job_name].extend(tasks)
+ else:
+ if isinstance(tasks, list):
+ # We convert to a dictionary if the type is a list.
+ task_dict = dict(zip(range(0, len(tasks)), tasks))
+ else:
+ # We can simply make a copy (for update) and be done.
+ task_dict = tasks.copy()
+
+ # We detect if there are duplicates, and raise an error if so.
+ task_keys = set(task_dict)
+ merged_keys = set(merged_cluster[job_name].keys())
+ intersected_keys = task_keys.intersection(merged_keys)
+ if intersected_keys:
+ raise KeyError('Duplicate keys detected when merging two '
+ 'ClusterSpecs: %s' % repr(intersected_keys))
+
+ # We do the merge after all the processing.
+ merged_cluster[job_name].update(task_dict)
+
+ return ClusterSpec(merged_cluster)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
new file mode 100644
index 0000000000..dbfb77723c
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
@@ -0,0 +1,238 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Cluster Resolvers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+class UnionClusterResolverTest(test.TestCase):
+ # TODO(frankchn): Transform to parameterized test after it is included in the
+ # TF open source codebase.
+
+ def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
+ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto,
+ server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto,
+ server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+
+ def testSingleClusterResolver(self):
+ base_cluster_spec = server_lib.ClusterSpec({
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ })
+ simple_resolver = SimpleClusterResolver(base_cluster_spec)
+ union_resolver = UnionClusterResolver(simple_resolver)
+
+ expected_proto = """
+ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+ tasks { key: 1 value: 'ps1:2222' } }
+ job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
+ tasks { key: 1 value: 'worker1:2222' }
+ tasks { key: 2 value: 'worker2:2222' } }
+ """
+ actual_cluster_spec = union_resolver.cluster_spec()
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+ def testTwoNonOverlappingJobMergedClusterResolver(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "ps": [
+ "ps0:2222",
+ "ps1:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": [
+ "worker0:2222",
+ "worker1:2222",
+ "worker2:2222"
+ ]
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+ tasks { key: 1 value: 'ps1:2222' } }
+ job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
+ tasks { key: 1 value: 'worker1:2222' }
+ tasks { key: 2 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testOverlappingJobMergedClusterResolver(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": [
+ "worker0:2222",
+ "worker1:2222",
+ "worker2:2222"
+ ]
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
+ tasks { key: 1 value: 'worker5:2222' }
+ tasks { key: 2 value: 'worker0:2222' }
+ tasks { key: 3 value: 'worker1:2222' }
+ tasks { key: 4 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testOverlappingSparseJobMergedClusterResolverThrowError(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": {
+ 7: "worker4:2222",
+ 9: "worker5:2222"
+ }
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ self.assertRaises(KeyError, union_cluster.cluster_spec)
+
+ def testOverlappingDictAndListThrowError(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 1: "worker0:2222",
+ 2: "worker1:2222",
+ 3: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ self.assertRaises(KeyError, union_cluster.cluster_spec)
+
+ def testOverlappingJobNonOverlappingKey(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": {
+ 5: "worker4:2222",
+ 9: "worker5:2222"
+ }
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 3 value: 'worker0:2222' }
+ tasks { key: 5 value: 'worker4:2222' }
+ tasks { key: 6 value: 'worker1:2222' }
+ tasks { key: 7 value: 'worker2:2222' }
+ tasks { key: 9 value: 'worker5:2222' }}
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testMixedModeNonOverlappingKey(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
+ tasks { key: 1 value: 'worker5:2222' }
+ tasks { key: 3 value: 'worker0:2222' }
+ tasks { key: 6 value: 'worker1:2222' }
+ tasks { key: 7 value: 'worker2:2222' }}
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testRetainSparseJobWithNoMerging(self):
+ base_cluster_spec = server_lib.ClusterSpec({
+ "worker": {
+ 1: "worker0:2222",
+ 3: "worker1:2222",
+ 5: "worker2:2222"
+ }
+ })
+
+ base_cluster_resolver = SimpleClusterResolver(base_cluster_spec)
+ union_cluster = UnionClusterResolver(base_cluster_resolver)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 1 value: 'worker0:2222' }
+ tasks { key: 3 value: 'worker1:2222' }
+ tasks { key: 5 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+
+if __name__ == "__main__":
+ test.main()