aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
new file mode 100644
index 0000000000..2edf3b599a
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -0,0 +1,105 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of Cluster Resolvers for Cloud TPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.python.training.server_lib import ClusterSpec
+
# googleapiclient is an optional dependency: record whether it could be
# imported so that the failure is reported only when the resolver actually
# needs to build a service object, rather than at module import time.
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
else:
  _GOOGLE_API_CLIENT_INSTALLED = True
+
+
class TPUClusterResolver(ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service. As Cloud TPUs are in alpha, you will need to specify an API
  definition file for this to consume, in addition to a list of Cloud TPUs in
  your Google Cloud Platform project.
  """

  def __init__(self,
               api_definition,
               project,
               zone,
               tpu_names,
               credentials,
               job_name='tpu_worker',
               service=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      api_definition: (Alpha only) A copy of the JSON API definitions for
        Cloud TPUs. This will be removed once Cloud TPU enters beta. Only used
        when `service` is None.
      project: Name of the GCP project containing Cloud TPUs
      zone: Zone where the TPUs are located
      tpu_names: A list of names of the target Cloud TPUs.
      credentials: GCE Credentials. Only used when `service` is None.
      job_name: Name of the TensorFlow job the TPUs belong to.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.

    Raises:
      ImportError: If no `service` is given and the googleapiclient is not
        installed.
    """

    self._project = project
    self._zone = zone
    self._tpu_names = tpu_names
    self._job_name = job_name
    if service is None:
      # A service object has to be built here, which needs googleapiclient.
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'TPU cluster resolver')

      # TODO(frankchn): Remove once Cloud TPU API Definitions are public and
      # replace with discovery.build('tpu', 'v1')
      self._service = discovery.build_from_document(api_definition,
                                                    credentials=credentials)
    else:
      self._service = service

  @staticmethod
  def _response_field(response, field):
    """Reads `field` from an API response.

    The googleapiclient returns the deserialized JSON body, so the response is
    normally a dict and is read by key. Custom service objects may instead
    return objects exposing the fields as attributes, so attribute access is
    kept as a fallback for responses that do not support subscripting.

    Args:
      response: The value returned by `request.execute()`.
      field: Name of the field to read, e.g. 'ipAddress'.

    Returns:
      The value stored under `field`.

    Raises:
      KeyError: If `response` is a mapping that lacks `field`.
      AttributeError: If `response` is not subscriptable and has no such
        attribute.
    """
    try:
      return response[field]
    except TypeError:
      # Not subscriptable: fall back to attribute-style responses.
      return getattr(response, field)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs. It
      has a single job, named by `job_name`, whose tasks are the 'ip:port'
      addresses of the TPUs in the order given by `tpu_names`.
    """
    worker_list = []

    for tpu_name in self._tpu_names:
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, tpu_name)
      request = self._service.projects().locations().nodes().get(name=full_name)
      response = request.execute()

      instance_url = '%s:%s' % (self._response_field(response, 'ipAddress'),
                                self._response_field(response, 'port'))
      worker_list.append(instance_url)

    return ClusterSpec({self._job_name: worker_list})