diff options
author | Frank Chen <frankchn@google.com> | 2017-07-10 14:46:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-10 14:50:48 -0700 |
commit | 8b53854ea5e4b7ebeec4a8636580f54424cb1952 (patch) | |
tree | f0c69365fe76329d9376e2e82243fc13b0525cc2 /tensorflow/contrib/cluster_resolver | |
parent | e121535a7d04cfc7c7dbb09d8694c01eb29da26f (diff) |
Adds support for retrieving instances from the Google Compute Engine instance group APIs, with support (in conjunction with UnionClusterResolver) for mapping multiple instance groups into one TensorFlow job (see the `testUnionMultipleInstanceRetrieval` test for details).
This should simplify creating and using standardized grpc TensorFlow server based instances using Compute Engine instance groups for distributed training.
PiperOrigin-RevId: 161443891
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
4 files changed, 397 insertions, 1 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 0dfc5a81d5..1e2c17eb52 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -22,9 +22,18 @@ filegroup( ) py_library( + name = "cluster_resolver_pip", + srcs = ["python/training/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cluster_resolver_py", + ":gce_cluster_resolver_py", + ], +) + +py_library( name = "cluster_resolver_py", srcs = [ - "python/training/__init__.py", "python/training/cluster_resolver.py", ], srcs_version = "PY2AND3", @@ -33,6 +42,15 @@ py_library( ], ) +py_library( + name = "gce_cluster_resolver_py", + srcs = [ + "python/training/gce_cluster_resolver.py", + ], + srcs_version = "PY2AND3", + deps = [":cluster_resolver_py"], +) + tf_py_test( name = "cluster_resolver_py_test", size = "small", @@ -46,3 +64,17 @@ tf_py_test( ], main = "python/training/cluster_resolver_test.py", ) + +tf_py_test( + name = "gce_cluster_resolver_py_test", + size = "small", + srcs = ["python/training/gce_cluster_resolver_test.py"], + additional_deps = [ + ":gce_cluster_resolver_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + main = "python/training/gce_cluster_resolver_test.py", +) diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 3520467bc6..fbf7ca3a5d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -21,3 +21,4 @@ from __future__ import print_function from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py new file mode 100644 index 0000000000..2603d59920 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Cluster Resolvers for GCE Instance Groups.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver +from tensorflow.python.training.server_lib import ClusterSpec + +_GOOGLE_API_CLIENT_INSTALLED = True +try: + from googleapiclient import discovery # pylint: disable=g-import-not-at-top +except ImportError: + _GOOGLE_API_CLIENT_INSTALLED = False + + +class GceClusterResolver(ClusterResolver): + """Cluster Resolver for Google Compute Engine. + + This is an implementation of cluster resolvers for the Google Compute Engine + instance group platform. By specifying a project, zone, and instance group, + this will retrieve the IP address of all the instances within the instance + group and return a Cluster Resolver object suitable for use for distributed + TensorFlow. + """ + + def __init__(self, + project, + zone, + instance_group, + port, + job_name='worker', + credentials=None, + service=None): + """Creates a new GceClusterResolver object. + + This takes in a few parameters and creates a GceClusterResolver project. It + will then use these parameters to query the GCE API for the IP addresses of + each instance in the instance group. + + Args: + project: Name of the GCE project + zone: Zone of the GCE instance group + instance_group: Name of the GCE instance group + port: Port of the listening TensorFlow server (default: 8470) + job_name: Name of the TensorFlow job this set of instances belongs to + credentials: GCE Credentials. This defaults to + GoogleCredentials.get_application_default() + service: The GCE API object returned by the googleapiclient.discovery + function. (Default: discovery.build('compute', 'v1')). If you specify a + custom service object, then the credentials parameter will be ignored. + + Raises: + ImportError: If the googleapiclient is not installed. + """ + self._project = project + self._zone = zone + self._instance_group = instance_group + self._job_name = job_name + self._port = port + if service is None: + if _GOOGLE_API_CLIENT_INSTALLED is True: + self._service = discovery.build('compute', 'v1', + credentials=credentials) + else: + raise ImportError('googleapiclient must be installed before using the ' + 'GCE cluster resolver') + else: + self._service = service + + def cluster_spec(self): + """Returns a ClusterSpec object based on the latest instance group info. + + This returns a ClusterSpec object for use based on information from the + specified instance group. We will retrieve the information from the GCE APIs + every time this method is called. + + Returns: + A ClusterSpec containing host information retrieved from GCE. + """ + request_body = {'instanceState': 'RUNNING'} + request = self._service.instanceGroups().listInstances( + project=self._project, + zone=self._zone, + instanceGroups=self._instance_group, + body=request_body, + orderBy='name') + + worker_list = [] + + while request is not None: + response = request.execute() + + items = response['items'] + for instance in items: + instance_name = instance['instance'].split('/')[-1] + + instance_request = self._service.instances().get( + project=self._project, + zone=self._zone, + instance=instance_name) + + if instance_request is not None: + instance_details = instance_request.execute() + ip_address = instance_details['networkInterfaces'][0]['networkIP'] + instance_url = '%s:%s' % (ip_address, self._port) + worker_list.append(instance_url) + + request = self._service.instanceGroups().listInstances_next( + previous_request=request, + previous_response=response) + + worker_list.sort() + return ClusterSpec({self._job_name: worker_list}) diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py new file mode 100644 index 0000000000..f2deacbc26 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GceClusterResolver.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + + +mock = test.mock + + +class GceClusterResolverTest(test.TestCase): + + def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): + self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) + self.assertProtoEquals( + expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals( + expected_proto, + server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals( + expected_proto, + server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + + def standard_mock_instance_groups(self, instance_map=None): + if instance_map is None: + instance_map = [ + {'instance': 'https://gce.example.com/res/gce-instance-1'} + ] + + mock_instance_group_request = mock.MagicMock() + mock_instance_group_request.execute.return_value = { + 'items': instance_map + } + + service_attrs = { + 'listInstances.return_value': mock_instance_group_request, + 'listInstances_next.return_value': None, + } + mock_instance_groups = mock.Mock(**service_attrs) + return mock_instance_groups + + def standard_mock_instances(self, instance_to_ip_map=None): + if instance_to_ip_map is None: + instance_to_ip_map = { + 'gce-instance-1': '10.123.45.67' + } + + mock_get_request = mock.MagicMock() + mock_get_request.execute.return_value = { + 'networkInterfaces': [ + {'networkIP': '10.123.45.67'} + ] + } + + def get_side_effect(project, zone, instance): + del project, zone # Unused + + if instance in instance_to_ip_map: + mock_get_request = mock.MagicMock() + mock_get_request.execute.return_value = { + 'networkInterfaces': [ + {'networkIP': instance_to_ip_map[instance]} + ] + } + return mock_get_request + else: + raise RuntimeError('Instance %s not found!' % instance) + + service_attrs = { + 'get.side_effect': get_side_effect, + } + mock_instances = mock.MagicMock(**service_attrs) + return mock_instances + + def standard_mock_service_client( + self, + mock_instance_groups=None, + mock_instances=None): + + if mock_instance_groups is None: + mock_instance_groups = self.standard_mock_instance_groups() + if mock_instances is None: + mock_instances = self.standard_mock_instances() + + mock_client = mock.MagicMock() + mock_client.instanceGroups.return_value = mock_instance_groups + mock_client.instances.return_value = mock_instances + return mock_client + + def gen_standard_mock_service_client(self, instances=None): + name_to_ip = {} + instance_list = [] + for instance in instances: + name_to_ip[instance['name']] = instance['ip'] + instance_list.append({ + 'instance': 'https://gce.example.com/gce/res/' + instance['name'] + }) + + mock_instance = self.standard_mock_instances(name_to_ip) + mock_instance_group = self.standard_mock_instance_groups(instance_list) + + return self.standard_mock_service_client(mock_instance_group, mock_instance) + + def testSimpleSuccessfulRetrieval(self): + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + port=8470, + service=self.standard_mock_service_client()) + + actual_cluster_spec = gce_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testCustomJobNameAndPortRetrieval(self): + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + job_name='custom', + port=2222, + service=self.standard_mock_service_client()) + + actual_cluster_spec = gce_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testMultipleInstancesRetrieval(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + port=8470, + service=self.gen_standard_mock_service_client(name_to_ip)) + + actual_cluster_spec = gce_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } + tasks { key: 1 value: '10.2.3.4:8470' } + tasks { key: 2 value: '10.3.4.5:8470' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testUnionMultipleInstanceRetrieval(self): + worker1_name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + worker2_name_to_ip = [ + {'name': 'instance4', 'ip': '10.4.5.6'}, + {'name': 'instance5', 'ip': '10.5.6.7'}, + {'name': 'instance6', 'ip': '10.6.7.8'}, + ] + + ps_name_to_ip = [ + {'name': 'ps1', 'ip': '10.100.1.2'}, + {'name': 'ps2', 'ip': '10.100.2.3'}, + ] + + worker1_gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + job_name='worker', + port=8470, + service=self.gen_standard_mock_service_client(worker1_name_to_ip)) + + worker2_gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + job_name='worker', + port=8470, + service=self.gen_standard_mock_service_client(worker2_name_to_ip)) + + ps_gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + job_name='ps', + port=2222, + service=self.gen_standard_mock_service_client(ps_name_to_ip)) + + union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver, + worker2_gce_cluster_resolver, + ps_gce_cluster_resolver) + + actual_cluster_spec = union_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' } + tasks { key: 1 value: '10.100.2.3:2222' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } + tasks { key: 1 value: '10.2.3.4:8470' } + tasks { key: 2 value: '10.3.4.5:8470' } + tasks { key: 3 value: '10.4.5.6:8470' } + tasks { key: 4 value: '10.5.6.7:8470' } + tasks { key: 5 value: '10.6.7.8:8470' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + +if __name__ == '__main__': + test.main() |