aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2017-07-27 15:05:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-27 15:08:54 -0700
commit28373cfe70dbb69031295fb3254e56f8b765b229 (patch)
treef0b0abc7b6d5ef83f8c08f885decb72bc874cca2 /tensorflow/contrib/cluster_resolver
parente5353c941c4cfd7f256d69cc50caf6c90e70dd4a (diff)
Adds preliminary support for Cloud TPUs with Cluster Resolvers. This aims to allow users to have a better experience when specifying one or multiple Cloud TPUs for their training jobs by allowing users to use names rather than IP addresses.
PiperOrigin-RevId: 163393443
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD27
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/__init__.py1
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py105
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py111
4 files changed, 244 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index ece5b9c04c..9d27d3c161 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -28,6 +28,7 @@ py_library(
deps = [
":cluster_resolver_py",
":gce_cluster_resolver_py",
+ ":tpu_cluster_resolver_py",
],
)
@@ -54,6 +55,18 @@ py_library(
],
)
+# Library target for the Cloud TPU cluster resolver
+# (python/training/tpu_cluster_resolver.py).
+py_library(
+ name = "tpu_cluster_resolver_py",
+ srcs = [
+ "python/training/tpu_cluster_resolver.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cluster_resolver_py",
+ "//tensorflow/python:training",
+ ],
+)
+
tf_py_test(
name = "cluster_resolver_py_test",
size = "small",
@@ -81,3 +94,17 @@ tf_py_test(
],
main = "python/training/gce_cluster_resolver_test.py",
)
+
+# Unit test target for the Cloud TPU cluster resolver
+# (python/training/tpu_cluster_resolver_test.py).
+tf_py_test(
+ name = "tpu_cluster_resolver_py_test",
+ size = "small",
+ srcs = ["python/training/tpu_cluster_resolver_test.py"],
+ additional_deps = [
+ ":tpu_cluster_resolver_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ main = "python/training/tpu_cluster_resolver_test.py",
+)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
index fbf7ca3a5d..0b0464b7d2 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
@@ -22,3 +22,4 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
new file mode 100644
index 0000000000..2edf3b599a
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -0,0 +1,105 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of Cluster Resolvers for Cloud TPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.python.training.server_lib import ClusterSpec
+
+_GOOGLE_API_CLIENT_INSTALLED = True
+try:
+ from googleapiclient import discovery # pylint: disable=g-import-not-at-top
+except ImportError:
+ _GOOGLE_API_CLIENT_INSTALLED = False
+
+
+class TPUClusterResolver(ClusterResolver):
+  """Cluster Resolver for Google Cloud TPUs.
+
+  This is an implementation of cluster resolvers for the Google Cloud TPU
+  service. As Cloud TPUs are in alpha, you will need to specify an API
+  definition file for this to consume, in addition to a list of Cloud TPUs in
+  your Google Cloud Platform project.
+  """
+
+  def __init__(self,
+               api_definition,
+               project,
+               zone,
+               tpu_names,
+               credentials,
+               job_name='tpu_worker',
+               service=None):
+    """Creates a new TPUClusterResolver object.
+
+    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
+    for the IP addresses and ports of each Cloud TPU listed.
+
+    Args:
+      api_definition: (Alpha only) A copy of the JSON API definitions for
+        Cloud TPUs. This will be removed once Cloud TPU enters beta. Only used
+        when `service` is None.
+      project: Name of the GCP project containing Cloud TPUs.
+      zone: Zone where the TPUs are located.
+      tpu_names: A list of names of the target Cloud TPUs.
+      credentials: GCE Credentials. Only used when `service` is None.
+      job_name: Name of the TensorFlow job the TPUs belong to.
+      service: The GCE API object returned by the googleapiclient.discovery
+        function. If you specify a custom service object, then the
+        api_definition and credentials parameters will be ignored.
+
+    Raises:
+      ImportError: If `service` is None and the googleapiclient is not
+        installed.
+    """
+
+    self._project = project
+    self._zone = zone
+    self._tpu_names = tpu_names
+    self._job_name = job_name
+    if service is None:
+      # The API client is only needed when we have to build the service
+      # ourselves; callers that inject a service can run without it.
+      if not _GOOGLE_API_CLIENT_INSTALLED:
+        raise ImportError('googleapiclient must be installed before using the '
+                          'TPU cluster resolver')
+
+      # TODO(frankchn): Remove once Cloud TPU API Definitions are public and
+      # replace with discovery.build('tpu', 'v1')
+      self._service = discovery.build_from_document(api_definition,
+                                                    credentials=credentials)
+    else:
+      self._service = service
+
+  def cluster_spec(self):
+    """Returns a ClusterSpec object based on the latest TPU information.
+
+    We retrieve the information from the GCE APIs every time this method is
+    called. Workers are listed in the same order as `tpu_names`, so the task
+    index of each TPU matches its position in that list.
+
+    Returns:
+      A ClusterSpec containing host information returned from Cloud TPUs.
+    """
+    worker_list = []
+
+    for tpu_name in self._tpu_names:
+      full_name = 'projects/%s/locations/%s/nodes/%s' % (
+          self._project, self._zone, tpu_name)
+      request = self._service.projects().locations().nodes().get(name=full_name)
+      # execute() returns the decoded JSON response as a plain dict, so the
+      # fields must be read by key rather than as attributes.
+      response = request.execute()
+
+      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+      worker_list.append(instance_url)
+
+    return ClusterSpec({self._job_name: worker_list})
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
new file mode 100644
index 0000000000..5bd5cd1a87
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -0,0 +1,111 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TPUClusterResolver."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+mock = test.mock
+
+
+class TPUClusterResolverTest(test.TestCase):
+  """Tests TPUClusterResolver against a mocked Cloud TPU API client."""
+
+  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
+    """Verifies that the ClusterSpec generates the correct proto.
+
+    We are testing this four different ways to ensure that the ClusterSpec
+    returned by the TPUClusterResolver behaves identically to a normal
+    ClusterSpec when passed into the generic ClusterSpec libraries.
+
+    Args:
+      cluster_spec: ClusterSpec returned by the TPUClusterResolver
+      expected_proto: Expected protobuf
+    """
+    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+
+  def mock_service_client(self, tpu_map=None):
+    """Builds a mock of the Cloud TPU API service object.
+
+    The resolver calls `service.projects().locations().nodes().get(name=...)`
+    and then `.execute()` on the returned request, so the mock has to be wired
+    through `return_value` at every call in that chain.
+
+    Args:
+      tpu_map: Dict mapping fully-qualified TPU node names to the response dict
+        that `execute()` should return for that node. Defaults to empty.
+
+    Returns:
+      A MagicMock standing in for the API service object.
+    """
+    if tpu_map is None:
+      tpu_map = {}
+
+    def get_side_effect(name):
+      # Looking up an unknown node raises KeyError, failing the test loudly.
+      request = mock.MagicMock()
+      request.execute.return_value = tpu_map[name]
+      return request
+
+    mock_client = mock.MagicMock()
+    nodes = (mock_client.projects.return_value
+             .locations.return_value
+             .nodes.return_value)
+    nodes.get.side_effect = get_side_effect
+    return mock_client
+
+  def testSimpleSuccessfulRetrieval(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        # api_definition is required but unused when a service is supplied.
+        api_definition=None,
+        project='test-project',
+        zone='us-central1-c',
+        tpu_names=['test-tpu-1'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+  def testMultipleSuccessfulRetrieval(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470'
+        },
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
+            'ipAddress': '10.4.5.6',
+            'port': '8470'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        # api_definition is required but unused when a service is supplied.
+        api_definition=None,
+        project='test-project',
+        zone='us-central1-c',
+        # Deliberately not in sorted order: task indices must follow this list.
+        tpu_names=['test-tpu-2', 'test-tpu-1'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
+                             tasks { key: 1 value: '10.1.2.3:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)