author     Yuefeng Zhou <yuefengz@google.com>  2018-08-16 13:34:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 13:41:59 -0700
commit     389a02fe6d937f2d884a0ac76f325f91c142653a (patch)
tree       e4864678a8c698a1fe6bbaced6f21d8cc1854848 /tensorflow/python/distribute
parent     9c50882415cb87a7eb81048d42401c64bf0617ef (diff)
Add multi_worker_util, which provides normalize_cluster_spec and is_chief.
PiperOrigin-RevId: 209037977
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/BUILD                      |  31
-rw-r--r--  tensorflow/python/distribute/multi_worker_util.py       |  80
-rw-r--r--  tensorflow/python/distribute/multi_worker_util_test.py  | 107
3 files changed, 218 insertions(+), 0 deletions(-)
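
The commit adds two helpers, normalize_cluster_spec and is_chief. A minimal usage sketch, assuming a TensorFlow source checkout where the new module is importable; the addresses are illustrative and mirror the tests below:

    from tensorflow.python.distribute import multi_worker_util

    cluster = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
    }

    # normalize_cluster_spec accepts a dict, a ClusterDef proto or a
    # ClusterSpec, and always returns a tf.train.ClusterSpec.
    spec = multi_worker_util.normalize_cluster_spec(cluster)
    print(spec.jobs)  # The job names, e.g. ["chief", "worker"].

    # An explicit "chief" task is always the chief, so worker 0 is not.
    print(multi_worker_util.is_chief(cluster, "chief", 0))   # True
    print(multi_worker_util.is_chief(cluster, "worker", 0))  # False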
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 16fbe3f4b5..98ef9bf492 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -50,3 +50,34 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [],
 )
+
+py_library(
+    name = "multi_worker_util",
+    srcs = [
+        "multi_worker_util.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:training",
+    ],
+)
+
+py_test(
+    name = "multi_worker_util_test",
+    srcs = ["multi_worker_util_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":multi_worker_util",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python/eager:test",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
new file mode 100644
index 0000000000..360733eff6
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for multi-worker distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import server_lib
+
+
+def normalize_cluster_spec(cluster_spec):
+ """Makes `cluster_spec` into a `ClusterSpec` object.
+
+ Args:
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+
+ Returns:
+ a `ClusterSpec` object.
+
+ Raises:
+ ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
+ `ClusterDef`.
+ """
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+def is_chief(cluster_spec, task_type, task_id):
+ """Returns whether the given task is chief in the cluster.
+
+ Args:
+ cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
+ cluster configurations.
+ task_type: the task type in the cluster.
+ task_id: the task id in the cluster.
+
+ Returns:
+ a boolean indicating whether the given task is chief.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
+ the maximum id of the `task_type`.
+ """
+ cluster_spec = normalize_cluster_spec(cluster_spec)
+ if task_type not in cluster_spec.jobs:
+ raise ValueError(
+ "The task_type \"%s\" is not in the `cluster_spec`." % task_type)
+ if task_id >= cluster_spec.num_tasks(task_type):
+ raise ValueError("The `task_id` %d exceeds the maximum id of %s." % (
+ task_id, task_type))
+
+ if task_type == "chief":
+ return True
+
+ # If chief not in the cluster_spec, use the first worker as chief. This is
+ # common in CollectiveAllReduceStrategy.
+ if ("chief" not in cluster_spec.jobs and task_type == "worker" and
+ task_id == 0):
+ return True
+ return False
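
The fallback in the last branch above is worth a short sketch: when the cluster has no "chief" job, worker 0 is treated as the chief (the CollectiveAllReduceStrategy case). Addresses are again illustrative:

    from tensorflow.python.distribute import multi_worker_util

    workers_only = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
    print(multi_worker_util.is_chief(workers_only, "worker", 0))  # True
    print(multi_worker_util.is_chief(workers_only, "worker", 1))  # False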
diff --git a/tensorflow/python/distribute/multi_worker_util_test.py b/tensorflow/python/distribute/multi_worker_util_test.py
new file mode 100644
index 0000000000..bdc49725c7
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util_test.py
@@ -0,0 +1,107 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi_worker_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import test
+from tensorflow.python.training import server_lib
+
+
+class NormalizeClusterSpecTest(test.TestCase):
+
+  def assert_same_cluster(self, lhs, rhs):
+    self.assertEqual(
+        server_lib.ClusterSpec(lhs).as_dict(),
+        server_lib.ClusterSpec(rhs).as_dict())
+
+  def testDictAsInput(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assert_same_cluster(
+        cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+  def testClusterDefAsInput(self):
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = "chief"
+    job.tasks[0] = "127.0.0.1:1234"
+
+    job = cluster_def.job.add()
+    job.name = "worker"
+    job.tasks[0] = "127.0.0.1:8964"
+    job.tasks[1] = "127.0.0.1:2333"
+
+    job = cluster_def.job.add()
+    job.name = "ps"
+    job.tasks[0] = "127.0.0.1:1926"
+    job.tasks[1] = "127.0.0.1:3141"
+
+    self.assert_same_cluster(
+        cluster_def, multi_worker_util.normalize_cluster_spec(cluster_def))
+
+  def testClusterSpecAsInput(self):
+    cluster_spec = server_lib.ClusterSpec({
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    })
+    self.assert_same_cluster(
+        cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+  def testUnexpectedInput(self):
+    cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"]
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        "`cluster_spec` should be a dict, a `tf.train.ClusterSpec` or a "
+        "`tf.train.ClusterDef` object"):
+      multi_worker_util.normalize_cluster_spec(cluster_spec)
+
+
+class IsChiefTest(test.TestCase):
+
+  def testClusterWithChief(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertTrue(multi_worker_util.is_chief(cluster_spec, "chief", 0))
+    self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+
+  def testClusterWithoutChief(self):
+    cluster_spec = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
+    self.assertTrue(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+    self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))
+
+    with self.assertRaisesRegexp(
+        ValueError, "The task_type \"chief\" is not in the `cluster_spec`."):
+      multi_worker_util.is_chief(cluster_spec, "chief", 0)
+
+    with self.assertRaisesRegexp(
+        ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
+      multi_worker_util.is_chief(cluster_spec, "worker", 2)
+
+
+if __name__ == "__main__":
+  test.main()
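
In a real multi-worker job, the arguments to is_chief would typically come from the TF_CONFIG environment variable rather than be hard-coded. The following wiring is a hypothetical sketch: the {"cluster": ..., "task": ...} JSON layout is the conventional TensorFlow format, but nothing in this change parses it:

    import json
    import os

    from tensorflow.python.distribute import multi_worker_util

    # Assumed TF_CONFIG layout:
    #   {"cluster": {"worker": [...]}, "task": {"type": "worker", "index": 0}}
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    if tf_config:
      cluster = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
      task = tf_config["task"]
      if multi_worker_util.is_chief(cluster, task["type"], task["index"]):
        print("This task should handle checkpoints and exporting.")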