diff options
author | 2017-07-27 15:05:16 -0700 | |
---|---|---|
committer | 2017-07-27 15:08:54 -0700 | |
commit | 28373cfe70dbb69031295fb3254e56f8b765b229 (patch) | |
tree | f0b0abc7b6d5ef83f8c08f885decb72bc874cca2 /tensorflow/contrib/cluster_resolver | |
parent | e5353c941c4cfd7f256d69cc50caf6c90e70dd4a (diff) |
Adds preliminary support for Cloud TPUs with Cluster Resolvers. This aims to allow users to have a better experience when specifying one or multiple Cloud TPUs for their training jobs by allowing users to use names rather than IP addresses.
PiperOrigin-RevId: 163393443
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
4 files changed, 244 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index ece5b9c04c..9d27d3c161 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -28,6 +28,7 @@ py_library( deps = [ ":cluster_resolver_py", ":gce_cluster_resolver_py", + ":tpu_cluster_resolver_py", ], ) @@ -54,6 +55,18 @@ py_library( ], ) +py_library( + name = "tpu_cluster_resolver_py", + srcs = [ + "python/training/tpu_cluster_resolver.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cluster_resolver_py", + "//tensorflow/python:training", + ], +) + tf_py_test( name = "cluster_resolver_py_test", size = "small", @@ -81,3 +94,17 @@ tf_py_test( ], main = "python/training/gce_cluster_resolver_test.py", ) + +tf_py_test( + name = "tpu_cluster_resolver_py_test", + size = "small", + srcs = ["python/training/tpu_cluster_resolver_test.py"], + additional_deps = [ + ":tpu_cluster_resolver_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + main = "python/training/tpu_cluster_resolver_test.py", +) diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index fbf7ca3a5d..0b0464b7d2 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -22,3 +22,4 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver +from 
tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
new file mode 100644
index 0000000000..2edf3b599a
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -0,0 +1,105 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of Cluster Resolvers for Cloud TPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.python.training.server_lib import ClusterSpec
+
+_GOOGLE_API_CLIENT_INSTALLED = True
+try:
+  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
+except ImportError:
+  _GOOGLE_API_CLIENT_INSTALLED = False
+
+
+class TPUClusterResolver(ClusterResolver):
+  """Cluster Resolver for Google Cloud TPUs.
+
+  This is an implementation of cluster resolvers for the Google Cloud TPU
+  service. As Cloud TPUs are in alpha, you will need to specify an API
+  definition file for this to consume, in addition to a list of Cloud TPUs in
+  your Google Cloud Platform project.
+  """
+
+  def __init__(self,
+               api_definition,
+               project,
+               zone,
+               tpu_names,
+               credentials,
+               job_name='tpu_worker',
+               service=None):
+    """Creates a new TPUClusterResolver object.
+
+    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
+    for the IP addresses and ports of each Cloud TPU listed.
+
+    Args:
+      api_definition: (Alpha only) A copy of the JSON API definitions for
+        Cloud TPUs. This will be removed once Cloud TPU enters beta.
+      project: Name of the GCP project containing Cloud TPUs
+      zone: Zone where the TPUs are located
+      tpu_names: A list of names of the target Cloud TPUs.
+      credentials: GCE Credentials.
+      job_name: Name of the TensorFlow job the TPUs belong to.
+      service: The GCE API object returned by the googleapiclient.discovery
+        function. If you specify a custom service object, then the
+        api_definition and credentials parameters will be ignored.
+
+    Raises:
+      ImportError: If the googleapiclient is not installed.
+    """
+
+    self._project = project
+    self._zone = zone
+    self._tpu_names = tpu_names
+    self._job_name = job_name
+    if service is None:
+      if not _GOOGLE_API_CLIENT_INSTALLED:
+        raise ImportError('googleapiclient must be installed before using the '
+                          'TPU cluster resolver')
+
+      # TODO(frankchn): Remove once Cloud TPU API Definitions are public and
+      # replace with discovery.build('tpu', 'v1')
+      self._service = discovery.build_from_document(api_definition,
+                                                    credentials=credentials)
+    else:
+      self._service = service
+
+  def cluster_spec(self):
+    """Returns a ClusterSpec object based on the latest TPU information.
+
+    We retrieve the information from the GCE APIs every time this method is
+    called.
+
+    Returns:
+      A ClusterSpec containing host information returned from Cloud TPUs.
+    """
+    worker_list = []
+
+    for tpu_name in self._tpu_names:
+      full_name = 'projects/%s/locations/%s/nodes/%s' % (
+          self._project, self._zone, tpu_name)
+      request = self._service.projects().locations().nodes().get(name=full_name)
+      response = request.execute()
+      # execute() yields the decoded JSON response as a dict; index by key.
+      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+      worker_list.append(instance_url)
+
+    return ClusterSpec({self._job_name: worker_list})
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
new file mode 100644
index 0000000000..5bd5cd1a87
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -0,0 +1,133 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TPUClusterResolver."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+mock = test.mock
+
+
+class TPUClusterResolverTest(test.TestCase):
+
+  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
+    """Verifies that the ClusterSpec generates the correct proto.
+
+    We are testing this four different ways to ensure that the ClusterSpec
+    returned by the TPUClusterResolver behaves identically to a normal
+    ClusterSpec when passed into the generic ClusterSpec libraries.
+
+    Args:
+      cluster_spec: ClusterSpec returned by the TPUClusterResolver
+      expected_proto: Expected protobuf
+    """
+    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+
+  def mock_service_client(
+      self,
+      tpu_map=None):
+    """Builds a mock Cloud TPU API client backed by `tpu_map`.
+
+    Args:
+      tpu_map: Dict mapping a fully qualified node name to the response dict
+        that looking up that node should return. Defaults to an empty dict.
+
+    Returns:
+      A MagicMock whose projects().locations().nodes().get(name=...).execute()
+      call chain returns tpu_map[name].
+    """
+    if tpu_map is None:
+      tpu_map = {}
+
+    def get_side_effect(name):
+      # get() hands back a request object; the payload comes from execute().
+      request = mock.MagicMock()
+      request.execute.return_value = tpu_map[name]
+      return request
+
+    mock_client = mock.MagicMock()
+    # The resolver *calls* projects(), locations() and nodes(), so the side
+    # effect must hang off the call return values, not the bare attributes.
+    mock_client.projects().locations().nodes().get.side_effect = get_side_effect
+    return mock_client
+
+  def testSimpleSuccessfulRetrieval(self):
+    """A single TPU name resolves to a one-task ClusterSpec."""
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        api_definition=None,  # Unused because a service object is injected.
+        project='test-project',
+        zone='us-central1-c',
+        tpu_names=['test-tpu-1'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+  def testMultipleSuccessfulRetrieval(self):
+    """Multiple TPU names resolve to tasks in the order they were given."""
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470'
+        },
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
+            'ipAddress': '10.4.5.6',
+            'port': '8470'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        api_definition=None,  # Unused because a service object is injected.
+        project='test-project',
+        zone='us-central1-c',
+        tpu_names=['test-tpu-2', 'test-tpu-1'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
+                             tasks { key: 1 value: '10.1.2.3:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+
+if __name__ == '__main__':
+  test.main()
|