author | 2018-07-27 14:58:10 -0700
committer | 2018-07-27 15:02:45 -0700
commit | 2beb3a9d8b9df294e7635cc23d195a76fd78de79 (patch)
tree | c53e0ba8f738093c5e97e5bc2ee6a5b064afcf2b /tensorflow/python/distribute
parent | 470b43af3153942e4ef838610aa07e93f904fd5f (diff)
Add distribute_coordinator: a unified and split client for distributed training.
PiperOrigin-RevId: 206378953
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r-- | tensorflow/python/distribute/BUILD | 40
-rw-r--r-- | tensorflow/python/distribute/distribute_coordinator.py | 361
-rw-r--r-- | tensorflow/python/distribute/distribute_coordinator_test.py | 291
3 files changed, 692 insertions, 0 deletions
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD new file mode 100644 index 0000000000..a29043d8b8 --- /dev/null +++ b/tensorflow/python/distribute/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +py_library( + name = "distribute_coordinator", + srcs = [ + "distribute_coordinator.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + ], +) + +py_test( + name = "distribute_coordinator_test", + size = "small", + srcs = ["distribute_coordinator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":distribute_coordinator", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py new file mode 100644 index 0000000000..04c50dbafc --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -0,0 +1,361 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A unified and split coordinator for distributed TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import json +import os +import threading + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import server_lib + + +class _TaskType(object): + PS = "ps" + WORKER = "worker" + CHIEF = "chief" + EVALUATOR = "evaluator" + + +_coordinator_context = threading.local() + + +def get_current_coordinator_context(): + """Returns the current coordinator context.""" + try: + return _coordinator_context.current + except AttributeError: + return None + + +class _Barrier(object): + """A reusable barrier class for worker synchronization.""" + + def __init__(self, num_participants): + """Initializes the barrier object. + + Args: + num_participants: an integer which is the expected number of calls of + `wait` pass to through this barrier. 
+ """ + self._num_participants = num_participants + self._counter = 0 + self._flag = False + self._local_sense = threading.local() + self._lock = threading.Lock() + self._condition = threading.Condition() + + def wait(self): + """Waits until all other callers reach the same wait call.""" + if not hasattr(self._local_sense, "value"): + self._local_sense.value = False + self._local_sense.value = not self._flag + with self._lock: + self._counter += 1 + if self._counter == self._num_participants: + self._counter = 0 + self._flag = self._local_sense.value + with self._condition: + while self._flag != self._local_sense.value: + self._condition.wait() + self._condition.notify_all() + + +def _get_num_workers(cluster_spec): + """Gets number of workers including chief.""" + if not cluster_spec: + return 0 + return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len( + cluster_spec.as_dict().get(_TaskType.CHIEF, [])) + + +class _CoordinatorContext(object): + """The coordinator context class. + + This context object provides configuration information for each task. One + context manager with a coordinator context object will be created per + invocation to the `worker_fn` where `get_current_coordinator_context` can be + called to access the coordinator context object. + """ + + def __init__(self, + cluster_spec, + task_type, + task_id, + between_graph=False, + rpc_layer="grpc", + worker_barrier=None): + """Initialize the coordinator context object. + + Args: + cluster_spec: a ClusterSpec object. It can be empty or None in the local + training case. + task_type: a string indicating the role of the corresponding task, such as + "worker" or "ps". It can be None if it is local training or + `between_graph` is False. + task_id: an integer indicating id of the corresponding task. It can be + None if it is local training or `between_graph` is False. + between_graph: whether it is between-graph replication or not. + rpc_layer: optional string specifying the RPC protocol for communication + with worker masters. If None or empty, hosts in the `cluster_spec` will + be used directly. + worker_barrier: optional, the barrier object for worker synchronization. + + Raises: + ValueError: if task_type or task_id is Node or empty and it is distributed + between-graph replicated training. + """ + if cluster_spec and between_graph: + if not task_type or task_id is None: + raise ValueError("`task_type` and `task_id` must be set in the " + "distributed between-graph replicated training.") + if task_type not in cluster_spec.jobs: + raise ValueError("`task_type` %r not found in the `cluster_spec` %r" % + (task_type, cluster_spec)) + self._cluster_spec = cluster_spec + self._task_type = task_type + self._task_id = task_id + self._worker_barrier = worker_barrier + self._rpc_layer = rpc_layer + self._master_target = self._get_master_target() + self._num_workers = _get_num_workers(cluster_spec) + self._is_chief_node = self._is_chief() + + def __enter__(self): + old_context = get_current_coordinator_context() + if old_context: + raise ValueError( + "You cannot run distribute coordinator in a `worker_fn`.") + _coordinator_context.current = self + + def __exit__(self, unused_exception_type, unused_exception_value, + unused_traceback): + _coordinator_context.current = None + + def _get_master_target(self): + """Return the master target for a task.""" + # If cluster_spec is None or empty, we use local master. + if not self._cluster_spec: + return "local" + + # If task_type is None, then it is in-graph replicated training. 
In this + # case we use the chief or first worker's master target. + if not self._task_type: + if _TaskType.CHIEF in self._cluster_spec.jobs: + assert not self.between_graph + task_type = _TaskType.CHIEF + task_id = 0 + else: + assert _TaskType.WORKER in self._cluster_spec.jobs + task_type = _TaskType.WORKER + task_id = 0 + else: + task_type = self._task_type + task_id = self._task_id + + prefix = "" + if self._rpc_layer: + prefix = self._rpc_layer + "://" + return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0] + + def _is_chief(self): + """Return whether the task is the chief worker.""" + if (not self._cluster_spec or self._task_type in [_TaskType.CHIEF, None]): + return True + + # If not local and chief not in the cluster_spec, use the first worker as + # chief. + if (_TaskType.CHIEF not in self._cluster_spec.jobs and + self._task_type == _TaskType.WORKER and self._task_id == 0): + return True + return False + + def wait_for_other_workers(self): + """Waits for other workers to reach the same call to this method. + + Raises: + ValueError: if `worker_barrier` is not passed to the __init__ method. + """ + if not self._worker_barrier: + raise ValueError( + "`worker_barrier is not set in the coordinator context.`") + self._worker_barrier.wait() + + @property + def distributed_mode(self): + """Whether it is distributed training or not.""" + return bool(self._cluster_spec) + + @property + def cluster_spec(self): + """Returns a copy of the cluster_spec object.""" + return copy.deepcopy(self._cluster_spec) + + @property + def task_type(self): + """Returns the role of the corresponing task.""" + return self._task_type + + @property + def task_id(self): + """Returns the id or index of the corresponing task.""" + return self._task_id + + @property + def master_target(self): + """Returns the session master for the corresponding task to connect to.""" + return self._master_target + + @property + def is_chief(self): + """Returns whether the task is a chief node.""" + return self._is_chief_node + + @property + def num_workers(self): + """Returns number of workers in the cluster, including chief.""" + return self._num_workers + + +def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer, + worker_barrier): + with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph, + rpc_layer, worker_barrier): + worker_fn() + + +def run_distribute_coordinator(worker_fn, + cluster_spec=None, + between_graph=False, + rpc_layer=None): + """Run the coordinator for distributed TensorFlow. + + This function runs a unified and split coordinator for distributed TensorFlow. + Given a `cluster_spec` specifying server addresses and their roles in a + cluster, this coordinator will figure out how to set them up, give the + underlying function the right targets for master sessions and coordinate their + training. + + In addition to be the distribute coordinator, this is also the source of + configurations for each job in the distributed training. As there are multiple + ways to configure a distributed TensorFlow cluster, its context object + provides these configurations so that users or higher-level APIs don't have to + figure out the configuration for each job by themselves. + + In the between-graph replicated training, this coordinator will create + multiple threads and each calls the `worker_fn` which is supposed to create + its own graph and connect to one worker master given by its coordinator + context. 
In the in-graph replicated training, it has only one thread calling + this `worker_fn`. + + The `worker_fn` defines the training logic and is called under a its own + coordinator context which can be accessed to via + `get_current_coordinator_context`. A coordinator context provides access to + configurations for each task, e.g. the task_type, task_id, master target and + so on. Since `worker_fn` will be called in a thread and possibly multiple + times, caller should be careful when it accesses global data. For example, it + is unsafe to define flags in a `worker_fn` or to define different environment + variables for different `worker_fn`s. + + The `worker_fn` for the between-graph replication is defined as if there are + only one worker corresponding to the `worker_fn` and possibly ps jobs. It + assigns variables to parameter servers and all other operations to that + worker. In the in-graph replication case, the `worker_fn` has to define + operations for all worker jobs. Using a distribution strategy can simplify the + `worker_fn` by not having to worry about the replication and device assignment + of variables and operations. + + This method is intended to be invoked by high-level APIs so that users don't + have to explictly call it to run this coordinator. For those who don't use + high-level APIs, to change a program to use this coordinator, wrap everything + in a the program after global data definitions such as commandline flag + definition into the `worker_fn` and get task-specific configurations from + the coordinator context. + + The `cluster_spec` can be either passed by the argument or parsed from the + "TF_CONFIG" envrionment variable. Example of a TF_CONFIG: + ``` + cluster = {'chief': ['host0:2222'], + 'ps': ['host1:2222', 'host2:2222'], + 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} + os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster}) + ``` + + If `cluster_spec` is not given in any format, it becomes local training and + this coordinator will connect to a local session. + + For evaluation, if "evaluator" exist in the cluster_spec, a separate thread + will be created with its `task_type` set to "evaluator". If "evaluator" is not + set in the cluster_spec, it entirely depends on the `worker_fn` for how to do + evaluation. + + Args: + worker_fn: the function to be called and given the access to a coordinator + context object. + cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles + in a cluster. If not set or empty, fall back to local training. + between_graph: a boolean. It is only useful when `cluster_spec` is set and + not empty. If true, it will use between-graph replicated training; + otherwise it will use in-graph replicated training. + rpc_layer: optional string, the protocol for RPC, e.g. "grpc". + + Raises: + ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or + a ClusterSpec. + """ + if not cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = tf_config.get("cluster", {}) + + if cluster_spec: + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + # TODO(yuefengz): validate cluster_spec. 
+ + threads = [] + if cluster_spec and _TaskType.EVALUATOR in cluster_spec.jobs: + t = threading.Thread( + target=_run, + args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0, between_graph, + rpc_layer, None)) + t.start() + threads.append(t) + + if cluster_spec and between_graph: + worker_barrier = _Barrier(_get_num_workers(cluster_spec)) + for task_type in [_TaskType.CHIEF, _TaskType.WORKER]: + for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): + t = threading.Thread( + target=_run, + args=(worker_fn, cluster_spec, task_type, task_id, between_graph, + rpc_layer, worker_barrier)) + t.start() + threads.append(t) + else: + # Local or in-graph replicated training. + _run(worker_fn, cluster_spec, None, None, between_graph, rpc_layer, None) + + # TODO(yuefengz): wrapper threads into thread coordinator? + for t in threads: + t.join() diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py new file mode 100644 index 0000000000..82fd823352 --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator_test.py @@ -0,0 +1,291 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for distribute coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import copy +import threading +import six + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.distribute import distribute_coordinator +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +CHIEF = distribute_coordinator._TaskType.CHIEF +WORKER = distribute_coordinator._TaskType.WORKER +PS = distribute_coordinator._TaskType.PS +EVALUATOR = distribute_coordinator._TaskType.EVALUATOR + +NUM_WORKERS = 3 +NUM_PS = 2 + + +def _bytes_to_str(maybe_bytes): + if isinstance(maybe_bytes, six.string_types): + return maybe_bytes + else: + return str(maybe_bytes, "utf-8") + + +class DistributeCoordinatorTest(test.TestCase): + + @classmethod + def setUpClass(cls): + # We have to create a global in-process cluster because once an in-process + # tensorflow server is created, there is no way to terminate it. Please see + # multi_worker_test_base.py for more details. 
+ cls._workers, cls._ps = test_util.create_local_cluster( + NUM_WORKERS, num_ps=NUM_PS) + cls._cluster_spec = { + WORKER: [_bytes_to_str(w.target) for w in cls._workers], + PS: [_bytes_to_str(ps.target) for ps in cls._ps] + } + + def setUp(self): + self._result_correct = 0 + self._lock = threading.Lock() + self._task_context = {} + + @contextlib.contextmanager + def _test_session(self, target): + config = config_pb2.ConfigProto(allow_soft_placement=True) + config.graph_options.optimizer_options.opt_level = -1 + with session.Session(graph=None, config=config, target=target) as sess: + yield sess + + def _in_graph_worker_fn(self): + context = distribute_coordinator.get_current_coordinator_context() + self.assertTrue(context is not None) + with self._test_session(target=context.master_target) as sess: + xs = [] + expected = 0.0 + for i in range(context.num_workers): + with ops.device("/job:worker/task:%d" % i): + x = variable_scope.get_variable("x_%d" % i, initializer=10.0) + x_add = x.assign_add(float(i)) + xs.append(x_add) + expected += i + 10.0 + + with ops.device("/job:worker/task:0"): + result = math_ops.add_n(xs) + + variables.global_variables_initializer().run() + result_value = sess.run(result) + self.assertEqual(result_value, expected) + if result_value == expected: + self._result_correct += 1 + + def testInGraph(self): + """Test it runs in-graph replicated training correctly.""" + distribute_coordinator.run_distribute_coordinator( + self._in_graph_worker_fn, + cluster_spec=self._cluster_spec, + between_graph=False) + self.assertEqual(self._result_correct, 1) + + def _between_graph_worker_fn(self): + context = distribute_coordinator.get_current_coordinator_context() + self.assertTrue(context is not None) + with self._test_session(target=context.master_target) as sess: + with ops.device("/job:ps/task:0"): + # TODO(yuefengz): investigate why not using resource variable will make + # the test flaky. + x = variable_scope.get_variable( + "x", initializer=10.0, use_resource=True) + with ops.device("/job:ps/task:1"): + y = variable_scope.get_variable( + "y", initializer=20.0, use_resource=True) + + x_add = x.assign_add(2.0) + y_sub = y.assign_sub(2.0) + train_op = control_flow_ops.group([x_add, y_sub]) + + if context.is_chief: + variables.global_variables_initializer().run() + + # Synchronize workers after initializaton. + context.wait_for_other_workers() + + sess.run(train_op) + + # Synchronize workers after one step to make sure they all have finished + # training. + context.wait_for_other_workers() + + x_val, y_val = sess.run([x, y]) + + self.assertEqual(x_val, 16.0) + self.assertEqual(y_val, 14.0) + if x_val == 16.0 and y_val == 14.0: + with self._lock: + self._result_correct += 1 + + def testBetweenGraph(self): + """Test it runs between-graph replicated training correctly.""" + distribute_coordinator.run_distribute_coordinator( + self._between_graph_worker_fn, + cluster_spec=self._cluster_spec, + between_graph=True) + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def _dump_task_context(self): + """Dumps the propoerties of each coordinator context. + + It dumps the context properties to a dict mapping from task_type to a list + of tuples of master_target, num_workers, is_chief and distribute_mode, where + the list is indexed by the task_id. 
+ """ + context = distribute_coordinator.get_current_coordinator_context() + self.assertTrue(context is not None) + task_type = str(context.task_type) + task_id = context.task_id or 0 + with self._lock: + if task_type not in self._task_context: + self._task_context[task_type] = [] + while len(self._task_context[task_type]) <= task_id: + self._task_context[task_type].append(None) + self._task_context[task_type][task_id] = (context.master_target, + context.num_workers, + context.is_chief, + context.distributed_mode) + + def testBetweenGraphContext(self): + # Dumps the task contexts to the self._task_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_task_context, + cluster_spec=self._cluster_spec, + between_graph=True) + + # There is only one type of task and there three such tasks. + self.assertEqual(len(self._task_context), 1) + self.assertTrue(WORKER in self._task_context) + self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._task_context[WORKER][0], + (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True)) + self.assertEqual( + self._task_context[WORKER][1], + (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True)) + self.assertEqual( + self._task_context[WORKER][2], + (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True)) + + def testInGraphContext(self): + # Dumps the task contexts to the self._task_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_task_context, + cluster_spec=self._cluster_spec, + between_graph=False) + + # There is only a "None" task in the dumped task context. + self.assertEqual(len(self._task_context), 1) + self.assertTrue("None" in self._task_context) + self.assertEqual(len(self._task_context["None"]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._task_context["None"][0], + (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True)) + + def testLocalContext(self): + # Dumps the task contexts to the self._task_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_task_context, cluster_spec=None, between_graph=True) + + # There is only a "None" task. + self.assertEqual(len(self._task_context), 1) + self.assertTrue("None" in self._task_context) + self.assertEqual(len(self._task_context["None"]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._task_context["None"][0], ("local", 0, True, False)) + + def testBetweenGraphContextWithChief(self): + # Adds a chief node, so there are NUM_WORKERS + 1 workers in total. + cluster_spec = copy.deepcopy(self._cluster_spec) + cluster_spec[CHIEF] = ["fake_chief"] + + # Dumps the task contexts to the self._task_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_task_context, + cluster_spec=cluster_spec, + between_graph=True, + rpc_layer="grpc") + + # There are one CHIEF and three workers. + self.assertEqual(len(self._task_context), 2) + self.assertTrue(CHIEF in self._task_context) + self.assertTrue(WORKER in self._task_context) + self.assertEqual(len(self._task_context[CHIEF]), 1) + self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. 
+ self.assertEqual(self._task_context[CHIEF][0], + ("grpc://fake_chief", 4, True, True)) + self.assertEqual(self._task_context[WORKER][0], + ("grpc://" + _bytes_to_str(self._workers[0].target), + NUM_WORKERS + 1, False, True)) + self.assertEqual(self._task_context[WORKER][1], + ("grpc://" + _bytes_to_str(self._workers[1].target), + NUM_WORKERS + 1, False, True)) + self.assertEqual(self._task_context[WORKER][2], + ("grpc://" + _bytes_to_str(self._workers[2].target), + NUM_WORKERS + 1, False, True)) + + def testInGraphContextWithEval(self): + # Adds a EVALUATOR job. + cluster_spec = copy.deepcopy(self._cluster_spec) + cluster_spec[EVALUATOR] = ["fake_evaluator"] + + # Dumps the task contexts to the self._task_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_task_context, cluster_spec=cluster_spec, between_graph=False) + + # There are one "None" task and one EVALUATOR task. + self.assertEqual(len(self._task_context), 2) + self.assertTrue("None" in self._task_context) + self.assertTrue(EVALUATOR in self._task_context) + self.assertEqual(len(self._task_context["None"]), 1) + self.assertEqual(len(self._task_context[EVALUATOR]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._task_context["None"][0], + (_bytes_to_str(self._workers[0].target), 3, True, True)) + self.assertEqual(self._task_context[EVALUATOR][0], + ("fake_evaluator", 3, False, True)) + + +if __name__ == "__main__": + test.main() |