diff options
author | Frank Chen <frankchn@google.com> | 2017-06-07 22:38:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-07 22:43:36 -0700 |
commit | 3a2971bd8e2277bd6a32bd222852952b57b11fc4 (patch) | |
tree | ceb9461791e793baf3d9290ffda28bb9721a4cf5 /tensorflow/contrib/cluster_resolver | |
parent | cd5ac40b31afaec237aaee35007f2dc846caf811 (diff) |
Adds the base for ClusterResolvers, a new way of communicating with and retrieving cluster information for running distributed TensorFlow.
Implementations of this class would eventually allow users to simply point TensorFlow at a cluster management endpoint, and TensorFlow will automatically retrieve the host names/IPs and port numbers of TensorFlow workers from the cluster management service.
PiperOrigin-RevId: 158358761
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
5 files changed, 484 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD new file mode 100644 index 0000000000..34cdb2a132 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -0,0 +1,47 @@ +# Description: Operations defined for Cluster Resolvers + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = [ + "//tensorflow:__subpackages__", + ], +) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) + +py_library( + name = "cluster_resolver_py", + srcs = [ + "python/training/__init__.py", + "python/training/cluster_resolver.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + ], +) + +tf_py_test( + name = "cluster_resolver_py_test", + srcs = ["python/training/cluster_resolver_test.py"], + additional_deps = [ + ":cluster_resolver_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + main = "python/training/cluster_resolver_test.py", +) diff --git a/tensorflow/contrib/cluster_resolver/README.md b/tensorflow/contrib/cluster_resolver/README.md new file mode 100644 index 0000000000..6fe6871eb4 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/README.md @@ -0,0 +1,5 @@ +# Cluster Resolvers + +Cluster Resolvers are a new way of specifying cluster information for distributed execution. Built on top of existing `ClusterSpec` framework, Cluster Resolvers allow users to simply specify a configuration and a cluster management service and a `ClusterResolver` will automatically fetch the relevant information from the service and populate `ClusterSpec`s. 
+ +`ClusterResolvers` are designed to work well with `MonitoredTrainingSession` and `ClusterSpec` propagation so that distributed training sessions remain robust in the face of node and network failures. diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py new file mode 100644 index 0000000000..3520467bc6 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.training.server_lib import ClusterSpec


class ClusterResolver(object):
  """Abstract class for all implementations of ClusterResolvers.

  This defines the skeleton for all implementations of ClusterResolvers.
  ClusterResolvers are a way for TensorFlow to communicate with various cluster
  management systems (e.g. GCE, AWS, etc...).

  By letting TensorFlow communicate with these systems, we will be able to
  automatically discover and resolve IP addresses for various TensorFlow
  workers. This will eventually allow us to automatically recover from
  underlying machine failures and scale TensorFlow worker clusters up and down.
  """

  # NOTE(review): `abc.abstractmethod` is only enforced when the class uses
  # `abc.ABCMeta` as its metaclass. As written, ClusterResolver can still be
  # instantiated and calling cluster_spec() raises NotImplementedError at call
  # time instead. Left unchanged to avoid breaking any caller that relies on
  # the current (non-enforcing) behavior — TODO confirm before tightening.
  @abc.abstractmethod
  def cluster_spec(self):
    """Retrieve the current state of the cluster and returns a ClusterSpec.

    Returns:
      A ClusterSpec representing the state of the cluster at the moment this
      function is called.

    Raises:
      NotImplementedError: Subclasses must override this method.

    Implementors of this function must take care in ensuring that the
    ClusterSpec returned is up-to-date at the time of calling this function.
    This usually means retrieving the information from the underlying cluster
    management system every time this function is invoked and reconstructing
    a cluster_spec, rather than attempting to cache anything.
    """
    raise NotImplementedError(
        'cluster_spec is not implemented for {}.'.format(self))


class SimpleClusterResolver(ClusterResolver):
  """Simple implementation of ClusterResolver that accepts a ClusterSpec."""

  def __init__(self, cluster_spec):
    """Creates a SimpleClusterResolver from a ClusterSpec.

    Args:
      cluster_spec: The `ClusterSpec` to wrap and return verbatim from
        `cluster_spec()`.

    Raises:
      TypeError: If `cluster_spec` is not a `ClusterSpec` instance.
    """
    super(SimpleClusterResolver, self).__init__()

    if not isinstance(cluster_spec, ClusterSpec):
      raise TypeError('cluster_spec must be a ClusterSpec.')
    self._cluster_spec = cluster_spec

  def cluster_spec(self):
    """Returns the ClusterSpec passed into the constructor."""
    return self._cluster_spec


class UnionClusterResolver(ClusterResolver):
  """Performs a union on underlying ClusterResolvers.

  This class performs a union given two or more existing ClusterResolvers. It
  merges the underlying ClusterResolvers, and returns one unified ClusterSpec
  when cluster_spec is called. The details of the merge function are
  documented in the cluster_spec function.
  """

  def __init__(self, *args):
    """Initializes a UnionClusterResolver with other ClusterResolvers.

    Args:
      *args: `ClusterResolver` objects to be unionized.

    Raises:
      TypeError: If any argument is not a subclass of `ClusterResolver`.
    """
    super(UnionClusterResolver, self).__init__()

    for cluster_resolver in args:
      if not isinstance(cluster_resolver, ClusterResolver):
        # Fixed message: backtick/period were transposed
        # ('`ClusterResolver.`' -> '`ClusterResolver`.').
        raise TypeError('All arguments must be a sub-class of '
                        '`ClusterResolver`.')
    self._cluster_resolvers = args

  def cluster_spec(self):
    """Returns a union of all the ClusterSpecs from the ClusterResolvers.

    Returns:
      A ClusterSpec containing host information merged from all the underlying
      ClusterResolvers.

    Raises:
      KeyError: If there are conflicting keys detected when merging two or
        more dictionaries, this exception is raised.

    Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
    same job name, we will merge the list/dict of workers.

    If *all* underlying ClusterSpecs expose the set of workers as lists, we
    will concatenate the lists of workers, starting with the list of workers
    from the first ClusterResolver passed into the constructor.

    If *any* of the ClusterSpecs expose the set of workers as a dict, we will
    treat all the sets of workers as dicts (even if they are returned as
    lists) and will only merge them into a dict if there are no conflicting
    keys. If there is a conflicting key, we will raise a `KeyError`.
    """

    merged_cluster = {}

    # First pass: figure out, per job, whether every resolver exposes the
    # tasks as a list (concatenation is possible) or whether at least one
    # exposes a dict (in which case everything is merged as a dict).
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if job_name in merged_cluster:
          # If we see a dict, then we write a dict out regardless.
          if isinstance(tasks, dict):
            merged_cluster[job_name] = {}
        else:
          # We take whichever type is present.
          if isinstance(tasks, list):
            merged_cluster[job_name] = []
          else:
            merged_cluster[job_name] = {}

    # Second pass: do the merge as appropriate in merged_cluster[job].
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if isinstance(merged_cluster[job_name], list):
          # We all have lists, we can just concatenate and be done.
          merged_cluster[job_name].extend(tasks)
        else:
          if isinstance(tasks, list):
            # We convert to a dictionary if the type is a list,
            # keyed by position within the list.
            task_dict = dict(enumerate(tasks))
          else:
            # We can simply make a copy (for update) and be done.
            task_dict = tasks.copy()

          # We detect if there are duplicates, and raise an error if so.
          task_keys = set(task_dict)
          merged_keys = set(merged_cluster[job_name].keys())
          intersected_keys = task_keys.intersection(merged_keys)
          if intersected_keys:
            raise KeyError('Duplicate keys detected when merging two '
                           'ClusterSpecs: %s' % repr(intersected_keys))

          # We do the merge after all the processing.
          merged_cluster[job_name].update(task_dict)

    return ClusterSpec(merged_cluster)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Cluster Resolvers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


class UnionClusterResolverTest(test.TestCase):
  # TODO(frankchn): Transform to parameterized test after it is included in the
  # TF open source codebase.

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Asserts `cluster_spec` matches `expected_proto` through every
    round-trip: directly, via copy-construction, via cluster_def, and via
    its dict form."""
    cluster_def = cluster_spec.as_cluster_def()
    self.assertProtoEquals(expected_proto, cluster_def)
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_def).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def testSingleClusterResolver(self):
    # A union over a single resolver is just that resolver's spec.
    base_cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    union_resolver = UnionClusterResolver(
        SimpleClusterResolver(base_cluster_spec))

    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_resolver.cluster_spec(),
                                    expected_proto)

  def testTwoNonOverlappingJobMergedClusterResolver(self):
    # Disjoint job names: both jobs appear unchanged in the union.
    ps_spec = server_lib.ClusterSpec({"ps": ["ps0:2222", "ps1:2222"]})
    worker_spec = server_lib.ClusterSpec(
        {"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(ps_spec), SimpleClusterResolver(worker_spec))

    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testOverlappingJobMergedClusterResolver(self):
    # Same job name, all lists: lists concatenate in constructor order.
    first_spec = server_lib.ClusterSpec(
        {"worker": ["worker4:2222", "worker5:2222"]})
    second_spec = server_lib.ClusterSpec(
        {"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(first_spec), SimpleClusterResolver(second_spec))

    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 2 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 4 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testOverlappingSparseJobMergedClusterResolverThrowError(self):
    # Both specs are sparse dicts and share task index 7 -> KeyError.
    first_spec = server_lib.ClusterSpec(
        {"worker": {7: "worker4:2222", 9: "worker5:2222"}})
    second_spec = server_lib.ClusterSpec(
        {"worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(first_spec), SimpleClusterResolver(second_spec))
    self.assertRaises(KeyError, union_cluster.cluster_spec)

  def testOverlappingDictAndListThrowError(self):
    # The list is promoted to {0: ..., 1: ...}, colliding with key 1.
    list_spec = server_lib.ClusterSpec(
        {"worker": ["worker4:2222", "worker5:2222"]})
    dict_spec = server_lib.ClusterSpec(
        {"worker": {1: "worker0:2222", 2: "worker1:2222", 3: "worker2:2222"}})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(list_spec), SimpleClusterResolver(dict_spec))
    self.assertRaises(KeyError, union_cluster.cluster_spec)

  def testOverlappingJobNonOverlappingKey(self):
    # Sparse dicts with disjoint keys merge cleanly.
    first_spec = server_lib.ClusterSpec(
        {"worker": {5: "worker4:2222", 9: "worker5:2222"}})
    second_spec = server_lib.ClusterSpec(
        {"worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(first_spec), SimpleClusterResolver(second_spec))

    expected_proto = """
    job { name: 'worker' tasks { key: 3 value: 'worker0:2222' }
                         tasks { key: 5 value: 'worker4:2222' }
                         tasks { key: 6 value: 'worker1:2222' }
                         tasks { key: 7 value: 'worker2:2222' }
                         tasks { key: 9 value: 'worker5:2222' }}
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testMixedModeNonOverlappingKey(self):
    # List promoted to {0, 1}; disjoint from the dict's {3, 6, 7}.
    list_spec = server_lib.ClusterSpec(
        {"worker": ["worker4:2222", "worker5:2222"]})
    dict_spec = server_lib.ClusterSpec(
        {"worker": {3: "worker0:2222", 6: "worker1:2222", 7: "worker2:2222"}})

    union_cluster = UnionClusterResolver(
        SimpleClusterResolver(list_spec), SimpleClusterResolver(dict_spec))

    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
                         tasks { key: 1 value: 'worker5:2222' }
                         tasks { key: 3 value: 'worker0:2222' }
                         tasks { key: 6 value: 'worker1:2222' }
                         tasks { key: 7 value: 'worker2:2222' }}
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)

  def testRetainSparseJobWithNoMerging(self):
    # A sparse (dict-form) job passes through a single-resolver union intact.
    sparse_spec = server_lib.ClusterSpec(
        {"worker": {1: "worker0:2222", 3: "worker1:2222", 5: "worker2:2222"}})

    union_cluster = UnionClusterResolver(SimpleClusterResolver(sparse_spec))

    expected_proto = """
    job { name: 'worker' tasks { key: 1 value: 'worker0:2222' }
                         tasks { key: 3 value: 'worker1:2222' }
                         tasks { key: 5 value: 'worker2:2222' } }
    """
    self._verifyClusterSpecEquality(union_cluster.cluster_spec(),
                                    expected_proto)


if __name__ == "__main__":
  test.main()