aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2017-07-10 14:46:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-10 14:50:48 -0700
commit8b53854ea5e4b7ebeec4a8636580f54424cb1952 (patch)
treef0c69365fe76329d9376e2e82243fc13b0525cc2 /tensorflow/contrib/cluster_resolver
parente121535a7d04cfc7c7dbb09d8694c01eb29da26f (diff)
Adds support for retrieving instances from the Google Compute Engine instance group APIs, with support (in conjunction with UnionClusterResolver) for mapping multiple instance groups into one TensorFlow job (see the `testUnionMultipleInstanceRetrieval` test for details).
This should simplify creating and using standardized gRPC-based TensorFlow server instances on Compute Engine instance groups for distributed training. PiperOrigin-RevId: 161443891
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD34
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/__init__.py1
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py129
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py234
4 files changed, 397 insertions, 1 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index 0dfc5a81d5..1e2c17eb52 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -22,9 +22,18 @@ filegroup(
)
py_library(
+ name = "cluster_resolver_pip",
+ srcs = ["python/training/__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cluster_resolver_py",
+ ":gce_cluster_resolver_py",
+ ],
+)
+
+py_library(
name = "cluster_resolver_py",
srcs = [
- "python/training/__init__.py",
"python/training/cluster_resolver.py",
],
srcs_version = "PY2AND3",
@@ -33,6 +42,15 @@ py_library(
],
)
+py_library(
+ name = "gce_cluster_resolver_py",
+ srcs = [
+ "python/training/gce_cluster_resolver.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [":cluster_resolver_py"],
+)
+
tf_py_test(
name = "cluster_resolver_py_test",
size = "small",
@@ -46,3 +64,17 @@ tf_py_test(
],
main = "python/training/cluster_resolver_test.py",
)
+
+tf_py_test(
+ name = "gce_cluster_resolver_py_test",
+ size = "small",
+ srcs = ["python/training/gce_cluster_resolver_test.py"],
+ additional_deps = [
+ ":gce_cluster_resolver_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ main = "python/training/gce_cluster_resolver_test.py",
+)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
index 3520467bc6..fbf7ca3a5d 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
@@ -21,3 +21,4 @@ from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py
new file mode 100644
index 0000000000..2603d59920
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of Cluster Resolvers for GCE Instance Groups."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.python.training.server_lib import ClusterSpec
+
+_GOOGLE_API_CLIENT_INSTALLED = True
+try:
+ from googleapiclient import discovery # pylint: disable=g-import-not-at-top
+except ImportError:
+ _GOOGLE_API_CLIENT_INSTALLED = False
+
+
+class GceClusterResolver(ClusterResolver):
+  """Cluster Resolver for Google Compute Engine.
+
+  This is an implementation of cluster resolvers for the Google Compute Engine
+  instance group platform. By specifying a project, zone, and instance group,
+  this will retrieve the IP address of all the instances within the instance
+  group and return a Cluster Resolver object suitable for use for distributed
+  TensorFlow.
+  """
+
+  def __init__(self,
+               project,
+               zone,
+               instance_group,
+               port=8470,
+               job_name='worker',
+               credentials=None,
+               service=None):
+    """Creates a new GceClusterResolver object.
+
+    This takes in a few parameters and creates a GceClusterResolver object. It
+    will then use these parameters to query the GCE API for the IP addresses of
+    each instance in the instance group.
+
+    Args:
+      project: Name of the GCE project
+      zone: Zone of the GCE instance group
+      instance_group: Name of the GCE instance group
+      port: Port of the listening TensorFlow server (default: 8470)
+      job_name: Name of the TensorFlow job this set of instances belongs to
+      credentials: GCE Credentials. This defaults to
+        GoogleCredentials.get_application_default()
+      service: The GCE API object returned by the googleapiclient.discovery
+        function. (Default: discovery.build('compute', 'v1')). If you specify a
+        custom service object, then the credentials parameter will be ignored.
+
+    Raises:
+      ImportError: If `service` is None and googleapiclient is not installed.
+    """
+    self._project = project
+    self._zone = zone
+    self._instance_group = instance_group
+    self._job_name = job_name
+    self._port = port
+    if service is None:
+      # No client was injected, so build the default one; this is the only path
+      # that needs googleapiclient to be importable.
+      if not _GOOGLE_API_CLIENT_INSTALLED:
+        raise ImportError('googleapiclient must be installed before using the '
+                          'GCE cluster resolver')
+      service = discovery.build('compute', 'v1', credentials=credentials)
+    self._service = service
+
+  def cluster_spec(self):
+    """Returns a ClusterSpec object based on the latest instance group info.
+
+    This returns a ClusterSpec object for use based on information from the
+    specified instance group. We will retrieve the information from the GCE APIs
+    every time this method is called.
+
+    Returns:
+      A ClusterSpec containing host information retrieved from GCE.
+    """
+    instance_groups = self._service.instanceGroups()
+    # NOTE(review): the Compute API reference names this path parameter
+    # `instanceGroup` (singular) -- confirm `instanceGroups=` is accepted by a
+    # real discovery client before relying on it.
+    request = instance_groups.listInstances(
+        project=self._project,
+        zone=self._zone,
+        instanceGroups=self._instance_group,
+        body={'instanceState': 'RUNNING'},
+        orderBy='name')
+
+    worker_list = []
+    while request is not None:
+      response = request.execute()
+      # A page with no matching instances may come back without an 'items' key,
+      # so treat a missing key as an empty page instead of raising KeyError.
+      for instance in response.get('items', []):
+        instance_name = instance['instance'].split('/')[-1]
+        instance_request = self._service.instances().get(
+            project=self._project,
+            zone=self._zone,
+            instance=instance_name)
+        if instance_request is not None:
+          instance_details = instance_request.execute()
+          ip_address = instance_details['networkInterfaces'][0]['networkIP']
+          worker_list.append('%s:%s' % (ip_address, self._port))
+
+      # listInstances_next is expected to return None after the last page.
+      request = instance_groups.listInstances_next(
+          previous_request=request,
+          previous_response=response)
+
+    # Sorted so the task order does not depend on the order of API responses.
+    worker_list.sort()
+    return ClusterSpec({self._job_name: worker_list})
diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py
new file mode 100644
index 0000000000..f2deacbc26
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GceClusterResolver."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+mock = test.mock
+
+
+class GceClusterResolverTest(test.TestCase):
+  """Tests GceClusterResolver against mocked GCE API service clients."""
+
+  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
+    """Checks cluster_spec and its round-trips against expected_proto."""
+    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+
+  def standard_mock_instance_groups(self, instance_map=None):
+    """Mocks instanceGroups(): a single page listing instance_map."""
+    if instance_map is None:
+      instance_map = [
+          {'instance': 'https://gce.example.com/res/gce-instance-1'}
+      ]
+
+    mock_instance_group_request = mock.MagicMock()
+    mock_instance_group_request.execute.return_value = {
+        'items': instance_map
+    }
+
+    service_attrs = {
+        'listInstances.return_value': mock_instance_group_request,
+        'listInstances_next.return_value': None,
+    }
+    mock_instance_groups = mock.Mock(**service_attrs)
+    return mock_instance_groups
+
+  def standard_mock_instances(self, instance_to_ip_map=None):
+    """Mocks instances(): get() resolves names via instance_to_ip_map."""
+    if instance_to_ip_map is None:
+      instance_to_ip_map = {
+          'gce-instance-1': '10.123.45.67'
+      }
+
+    def get_side_effect(project, zone, instance):
+      del project, zone  # Unused
+
+      if instance in instance_to_ip_map:
+        mock_get_request = mock.MagicMock()
+        mock_get_request.execute.return_value = {
+            'networkInterfaces': [
+                {'networkIP': instance_to_ip_map[instance]}
+            ]
+        }
+        return mock_get_request
+      else:
+        raise RuntimeError('Instance %s not found!' % instance)
+
+    service_attrs = {
+        'get.side_effect': get_side_effect,
+    }
+    mock_instances = mock.MagicMock(**service_attrs)
+    return mock_instances
+
+  def standard_mock_service_client(
+      self,
+      mock_instance_groups=None,
+      mock_instances=None):
+    """Builds a mock GCE service from the given (or default) resource mocks."""
+
+    if mock_instance_groups is None:
+      mock_instance_groups = self.standard_mock_instance_groups()
+    if mock_instances is None:
+      mock_instances = self.standard_mock_instances()
+
+    mock_client = mock.MagicMock()
+    mock_client.instanceGroups.return_value = mock_instance_groups
+    mock_client.instances.return_value = mock_instances
+    return mock_client
+
+  def gen_standard_mock_service_client(self, instances=None):
+    """Builds a mock GCE service from a list of {'name', 'ip'} dicts."""
+    name_to_ip = {}
+    instance_list = []
+    # `instances` defaults to None; treat that as an empty instance group.
+    for instance in instances or []:
+      name_to_ip[instance['name']] = instance['ip']
+      instance_list.append({
+          'instance': 'https://gce.example.com/gce/res/' + instance['name']
+      })
+
+    mock_instance = self.standard_mock_instances(name_to_ip)
+    mock_instance_group = self.standard_mock_instance_groups(instance_list)
+
+    return self.standard_mock_service_client(mock_instance_group, mock_instance)
+
+  def testSimpleSuccessfulRetrieval(self):
+    gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        port=8470,
+        service=self.standard_mock_service_client())
+
+    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+  def testCustomJobNameAndPortRetrieval(self):
+    gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        job_name='custom',
+        port=2222,
+        service=self.standard_mock_service_client())
+
+    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+  def testMultipleInstancesRetrieval(self):
+    name_to_ip = [
+        {'name': 'instance1', 'ip': '10.1.2.3'},
+        {'name': 'instance2', 'ip': '10.2.3.4'},
+        {'name': 'instance3', 'ip': '10.3.4.5'},
+    ]
+
+    gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        port=8470,
+        service=self.gen_standard_mock_service_client(name_to_ip))
+
+    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
+                         tasks { key: 1 value: '10.2.3.4:8470' }
+                         tasks { key: 2 value: '10.3.4.5:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+  def testUnionMultipleInstanceRetrieval(self):
+    worker1_name_to_ip = [
+        {'name': 'instance1', 'ip': '10.1.2.3'},
+        {'name': 'instance2', 'ip': '10.2.3.4'},
+        {'name': 'instance3', 'ip': '10.3.4.5'},
+    ]
+
+    worker2_name_to_ip = [
+        {'name': 'instance4', 'ip': '10.4.5.6'},
+        {'name': 'instance5', 'ip': '10.5.6.7'},
+        {'name': 'instance6', 'ip': '10.6.7.8'},
+    ]
+
+    ps_name_to_ip = [
+        {'name': 'ps1', 'ip': '10.100.1.2'},
+        {'name': 'ps2', 'ip': '10.100.2.3'},
+    ]
+
+    worker1_gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        job_name='worker',
+        port=8470,
+        service=self.gen_standard_mock_service_client(worker1_name_to_ip))
+
+    worker2_gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        job_name='worker',
+        port=8470,
+        service=self.gen_standard_mock_service_client(worker2_name_to_ip))
+
+    ps_gce_cluster_resolver = GceClusterResolver(
+        project='test-project',
+        zone='us-east1-d',
+        instance_group='test-instance-group',
+        job_name='ps',
+        port=2222,
+        service=self.gen_standard_mock_service_client(ps_name_to_ip))
+
+    union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver,
+                                                  worker2_gce_cluster_resolver,
+                                                  ps_gce_cluster_resolver)
+
+    actual_cluster_spec = union_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' }
+                     tasks { key: 1 value: '10.100.2.3:2222' } }
+    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
+                         tasks { key: 1 value: '10.2.3.4:8470' }
+                         tasks { key: 2 value: '10.3.4.5:8470' }
+                         tasks { key: 3 value: '10.4.5.6:8470' }
+                         tasks { key: 4 value: '10.5.6.7:8470' }
+                         tasks { key: 5 value: '10.6.7.8:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+if __name__ == '__main__':
+  test.main()