diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-08-16 13:34:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 13:41:59 -0700 |
commit | 389a02fe6d937f2d884a0ac76f325f91c142653a (patch) | |
tree | e4864678a8c698a1fe6bbaced6f21d8cc1854848 /tensorflow/python/distribute | |
parent | 9c50882415cb87a7eb81048d42401c64bf0617ef (diff) |
Add multi_worker_util which has normalize_cluster_spec and is_chief.
PiperOrigin-RevId: 209037977
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r-- | tensorflow/python/distribute/BUILD | 31 | ||||
-rw-r--r-- | tensorflow/python/distribute/multi_worker_util.py | 80 | ||||
-rw-r--r-- | tensorflow/python/distribute/multi_worker_util_test.py | 107 |
3 files changed, 218 insertions, 0 deletions
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 16fbe3f4b5..98ef9bf492 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -50,3 +50,34 @@ py_library( srcs_version = "PY2AND3", deps = [], ) + +py_library( + name = "multi_worker_util", + srcs = [ + "multi_worker_util.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + ], +) + +py_test( + name = "multi_worker_util_test", + srcs = ["multi_worker_util_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":multi_worker_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py new file mode 100644 index 0000000000..360733eff6 --- /dev/null +++ b/tensorflow/python/distribute/multi_worker_util.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for multi-worker distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import server_lib + + +def normalize_cluster_spec(cluster_spec): + """Makes `cluster_spec` into a `ClusterSpec` object. + + Args: + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + + Returns: + a `ClusterSpec` object. + + Raises: + ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a + `ClusterDef`. + """ + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +def is_chief(cluster_spec, task_type, task_id): + """Returns whether the given task is chief in the cluster. + + Args: + cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the + cluster configurations. + task_type: the task type in the cluster. + task_id: the task id in the cluster. + + Returns: + a boolean indicating whether the given task is chief. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds + the maximum id of the `task_type`. + """ + cluster_spec = normalize_cluster_spec(cluster_spec) + if task_type not in cluster_spec.jobs: + raise ValueError( + "The task_type \"%s\" is not in the `cluster_spec`." % task_type) + if task_id >= cluster_spec.num_tasks(task_type): + raise ValueError("The `task_id` %d exceeds the maximum id of %s." % ( + task_id, task_type)) + + if task_type == "chief": + return True + + # If chief not in the cluster_spec, use the first worker as chief. This is + # common in CollectiveAllReduceStrategy. + if ("chief" not in cluster_spec.jobs and task_type == "worker" and + task_id == 0): + return True + return False diff --git a/tensorflow/python/distribute/multi_worker_util_test.py b/tensorflow/python/distribute/multi_worker_util_test.py new file mode 100644 index 0000000000..bdc49725c7 --- /dev/null +++ b/tensorflow/python/distribute/multi_worker_util_test.py @@ -0,0 +1,107 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multi_worker_util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.eager import test +from tensorflow.python.training import server_lib + + +class NormalizeClusterSpecTest(test.TestCase): + + def assert_same_cluster(self, lhs, rhs): + self.assertEqual( + server_lib.ClusterSpec(lhs).as_dict(), + server_lib.ClusterSpec(rhs).as_dict()) + + def testDictAsInput(self): + cluster_spec = { + "chief": ["127.0.0.1:1234"], + "worker": ["127.0.0.1:8964", "127.0.0.1:2333"], + "ps": ["127.0.0.1:1926", "127.0.0.1:3141"] + } + self.assert_same_cluster( + cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec)) + + def testClusterDefAsInput(self): + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = "chief" + job.tasks[0] = "127.0.0.1:1234" + + job = cluster_def.job.add() + job.name = "worker" + job.tasks[0] = "127.0.0.1:8964" + job.tasks[1] = "127.0.0.1:2333" + + job = cluster_def.job.add() + job.name = "ps" + job.tasks[0] = "127.0.0.1:1926" + job.tasks[1] = "127.0.0.1:3141" + + self.assert_same_cluster( + cluster_def, multi_worker_util.normalize_cluster_spec(cluster_def)) + + def testClusterSpecAsInput(self): + cluster_spec = server_lib.ClusterSpec({ + "chief": ["127.0.0.1:1234"], + "worker": ["127.0.0.1:8964", "127.0.0.1:2333"], + "ps": ["127.0.0.1:1926", "127.0.0.1:3141"] + }) + self.assert_same_cluster( + cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec)) + + def testUnexpectedInput(self): + cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"] + + with self.assertRaisesRegexp( + ValueError, + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object"): + multi_worker_util.normalize_cluster_spec(cluster_spec) + + +class IsChiefTest(test.TestCase): + + def testClusterWithChief(self): + cluster_spec = { + "chief": ["127.0.0.1:1234"], + "worker": ["127.0.0.1:8964", "127.0.0.1:2333"], + "ps": ["127.0.0.1:1926", "127.0.0.1:3141"] + } + self.assertTrue(multi_worker_util.is_chief(cluster_spec, "chief", 0)) + self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 0)) + + def testClusterWithoutChief(self): + cluster_spec = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]} + self.assertTrue(multi_worker_util.is_chief(cluster_spec, "worker", 0)) + self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1)) + + with self.assertRaisesRegexp( + ValueError, "The task_type \"chief\" is not in the `cluster_spec`."): + multi_worker_util.is_chief(cluster_spec, "chief", 0) + + with self.assertRaisesRegexp( + ValueError, "The `task_id` 2 exceeds the maximum id of worker."): + multi_worker_util.is_chief(cluster_spec, "worker", 2) + + +if __name__ == "__main__": + test.main() |