From 493d7588172bcf476309b3954db342839ca37872 Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Thu, 2 Aug 2018 21:17:31 -0700 Subject: Add the CollectiveAllReduceStrategy. PiperOrigin-RevId: 207215423 --- tensorflow/contrib/distribute/python/BUILD | 56 +++++- .../python/collective_all_reduce_strategy.py | 205 +++++++++++++++++++ .../python/collective_all_reduce_strategy_test.py | 217 +++++++++++++++++++++ .../contrib/distribute/python/cross_tower_ops.py | 113 ++++++++++- .../distribute/python/cross_tower_ops_test.py | 172 ++++++++++++++++ .../contrib/distribute/python/cross_tower_utils.py | 151 +++++++++++++- .../contrib/distribute/python/mirrored_strategy.py | 138 +++++++------ .../distribute/python/multi_worker_test_base.py | 7 + 8 files changed, 984 insertions(+), 75 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py create mode 100644 tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index f6cc1dcc02..ecfe044b4f 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -133,6 +133,24 @@ py_library( ], ) +py_library( + name = "collective_all_reduce_strategy", + srcs = ["collective_all_reduce_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":cross_tower_utils", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "strategy_test_lib", testonly = 1, @@ -327,6 +345,37 @@ py_library( ], ) +py_test( + name = "collective_all_reduce_strategy_test", + srcs = ["collective_all_reduce_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":cross_tower_utils", + ":multi_worker_test_base", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "minimize_loss_test_lib", testonly = 1, @@ -497,8 +546,11 @@ py_library( "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", ], ) @@ -533,7 +585,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -541,6 +595,7 @@ py_library( cuda_py_test( name = "cross_tower_ops_test", + size = "large", srcs = 
["cross_tower_ops_test.py"], additional_deps = [ ":combinations", @@ -555,7 +610,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py new file mode 100644 index 0000000000..9afcaecf78 --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.training import server_lib + + +# TODO(yuefengz): move this function to a common util file. +def _normalize_cluster_spec(cluster_spec): + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): shard the dataset. +# TODO(yuefengz): support in-graph replication. +# TODO(yuefengz): it only works with a cluster without a chief node, maybe +# support chief node? +class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): + """Distribution strategy that uses collective ops for all-reduce. + + It is similar to the MirroredStrategy but it uses collective ops for + reduction. It currently only works for between-graph replication and its + reduction will reduce across all workers. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type="worker", + task_id=0): + """Initializes the object. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. 
+ """ + self._num_gpus_per_worker = num_gpus_per_worker + self._initialize(cluster_spec, task_type, task_id) + + def _initialize(self, cluster_spec, task_type, task_id): + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + if cluster_spec: + self._cluster_spec = _normalize_cluster_spec(cluster_spec) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) + if "chief" in self._cluster_spec.as_dict(): + num_workers += 1 + if not num_workers: + raise ValueError("`task_type` shoud be in `cluster_spec`.") + + # TODO(yuefengz): create a utility to infer chief. + if "chief" in self._cluster_spec.as_dict() and task_type == "chief": + assert task_id == 0 + self._is_chief = True + else: + assert task_type == "worker" + self._is_chief = task_id == 0 + else: + self._cluster_spec = None + self._is_chief = True + worker_device = "" + num_workers = 1 + self._num_workers = num_workers + + if self._num_gpus_per_worker: + local_devices = [ + "%s/device:GPU:%d" % (worker_device, i) + for i in range(self._num_gpus_per_worker) + ] + else: + local_devices = [worker_device] + + self._collective_keys = cross_tower_utils.CollectiveKeys() + super(CollectiveAllReduceStrategy, self).__init__( + devices=local_devices, + cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + num_workers=num_workers, + num_gpus_per_worker=self._num_gpus_per_worker, + collective_keys=self._collective_keys)) + + # Add a default device so that ops without specified devices will not end up + # on other workers. + if cluster_spec: + self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id) + + def _create_variable(self, next_creator, *args, **kwargs): + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + group_size = len(devices) * self._num_workers + group_key = self._collective_keys.get_group_key(self._devices) + + def _real_mirrored_creator(devices, *args, **kwargs): + """Creates one MirroredVariable on the current worker.""" + index = {} + collective_instance_key = self._collective_keys.get_instance_key( + key_id=kwargs["name"]) + if "initial_value" not in kwargs: + raise ValueError("Initial value must be specified.") + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value_fn = initial_value + else: + initial_value_fn = lambda: initial_value + + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other devices and other workers. 
+ def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + if self._is_chief and index == 0: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + + kwargs["initial_value"] = _overridden_initial_value_fn + + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + + assert not isinstance(v, values.DistributedVariable) + index[d] = v + return index + + # pylint: disable=protected-access + return mirrored_strategy._create_mirrored_variable( + devices, _real_mirrored_creator, *args, **kwargs) + + def configure(self, session_config=None): + # Use TF_CONFIG to get the cluster spec and the current job. + if not self._cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = 0 + + if cluster_spec: + self._initialize(cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py new file mode 100644 index 0000000000..b5e54e3b7d --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for CollectiveAllReduceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DistributedCollectiveAllReduceStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 0 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ] + } + + def setUp(self): + self._run_options = config_pb2.RunOptions() + self._run_options.experimental.collective_graph_key = 6 + + self._sess_config = config_pb2.ConfigProto() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + # We use a different key_base for each test so that collective keys won't be + # reused. + # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # tests. + DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + + def _get_test_object(self, task_type, task_id, num_gpus=0): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus, + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_start=num_gpus * 100 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base) + distribution._collective_keys = collective_keys + distribution._cross_tower_ops._collective_keys = collective_keys + return distribution, self._workers[task_id].target + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) 
+ return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + + for i in range(10): + b, a = sess.run((before_out, after_out), options=self._run_options) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def _test_variable_initialization(self, task_type, task_id, num_gpus): + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + distribution.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', + shape=(2, 3), + initializer=init_ops.random_uniform_initializer( + 1.0, 10.0, dtype=dtypes.float32)) + return array_ops.identity(x) + + x = distribution.call_for_each_tower(model_fn) + reduced_x = distribution.unwrap( + distribution.reduce( + variable_scope.VariableAggregation.MEAN, x, + destinations='/cpu:0'))[0] + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( + [x, reduced_x], options=self._run_options) + self.assertTrue(np.array_equal(x_value, reduced_x_value)) + return np.array_equal(x_value, reduced_x_value) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testMinimizeLossGraph(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. 
+ if context.num_gpus() < num_gpus: + return + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + self._test_minimize_loss_graph(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index b6037d2133..94be3fb740 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -267,9 +267,9 @@ def _group_value_by_device(per_device_values): This grouping is needed to call the all-reduce library because it expects a list of the following form: - [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... - (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... - (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], + [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], + [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], ... ] @@ -290,7 +290,10 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): +def _ungroup_and_make_mirrored(grouped_reduced, + destinations, + aggregation, + num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before @@ -303,6 +306,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): destinations: a list of device strings for returned Mirrored objects. aggregation: Indicates how a variable will be aggregated. Accepted values are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + num_between_graph_workers: number of workers in the between-graph + replication. Returns: a list of Mirrored objects. @@ -311,7 +316,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation): for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): if aggregation == vs.VariableAggregation.MEAN: - index[i][destinations[d]] = v / len(destinations) + index[i][destinations[d]] = v / ( + len(destinations) * num_between_graph_workers) else: index[i][destinations[d]] = v return [value_lib.Mirrored(v) for v in index] @@ -719,6 +725,103 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): aggregation) +# TODO(yuefengz): support in-graph collective all-reduce. +class CollectiveAllReduce(CrossTowerOps): + """All-reduce cross tower ops using collective ops. + + In the between-graph replicated training, it will still do all-reduces across + all workers and then put results on the right destinations. + """ + + def __init__(self, + num_workers=1, + num_gpus_per_worker=0, + all_reduce_merge_scope=1, + collective_keys=None): + """Initializes the object. + + Args: + num_workers: number of workers in the between-graph replicated training. + num_gpus_per_worker: number of GPUs per worker. + all_reduce_merge_scope: size of groups into which to partition consecutive + gradients grouped under a common 'allreduce' name scope. This is useful + for some optimization of collective ops. + collective_keys: an optional CollectiveKey object. 
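+
+    Example (a minimal sketch assuming 3 workers with 2 GPUs each):
+
+      cross_tower_ops = CollectiveAllReduce(
+          num_workers=3, num_gpus_per_worker=2)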
+ """ + self._num_workers = num_workers + self._num_gpus_per_worker = num_gpus_per_worker + self._all_reduce_merge_scope = all_reduce_merge_scope + self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( + ) + super(CollectiveAllReduce, self).__init__() + + # TODO(yuefengz, tucker): is index slices supported by collective ops? + def _reduce(self, aggregation, per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] + if destinations is None or _devices_match(per_device_value, destinations): + return all_reduced + else: + index = {} + for d in get_devices_from(destinations): + # pylint: disable=protected-access + if d in all_reduced._index: + index[d] = all_reduced._index[d] + else: + with ops.device(d): + index[d] = array_ops.identity(all_reduced._index.values()[0]) + return value_lib.Mirrored(index) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All-reduce across all workers in a batch.""" + if context.executing_eagerly(): + raise ValueError("Eager mode with collective ops is not supported yet.") + + logging.log_first_n( + logging.INFO, + "Collective All-reduce invoked with batches size = %d, " + "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + + grouped_by_tower = _group_value_by_device(per_device_values) + + grouped_by_var = list(zip(*grouped_by_tower)) + # grouped_by_var is grouped by variables and takes the following format: + # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), + # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), + # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), + # ... 
+ # ] + chunked_gv = [ + grouped_by_var[x:x + self._all_reduce_merge_scope] + for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) + ] + + reduced_gv_list = [] + for chunk in chunked_gv: + with ops.name_scope("allreduce"): + for grad_and_vars in chunk: + scaled_grads = [g for g, _ in grad_and_vars] + collective_reduced = cross_tower_utils.build_collective_reduce( + scaled_grads, self._num_workers, self._collective_keys, "Add", + "Id") + result = [] + for (_, v), g in zip(grad_and_vars, collective_reduced): + result.append([g, v]) + reduced_gv_list.append(result) + + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return _ungroup_and_make_mirrored( + new_tower_grads, + per_device_values[0].devices, + aggregation, + num_between_graph_workers=self._num_workers) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 6a780ff60f..21018006c5 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -21,13 +21,17 @@ from __future__ import print_function import itertools from absl.testing import parameterized +import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -376,5 +380,173 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_tower_ops, distribution) +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 10000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + "fake_worker_0", "fake_worker_1", "fake_worker_2" + ] + } + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. 
+ MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "local" + else: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return collective_all_reduce_ops, devices, self._workers[task_id].target + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(right._index.values()).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._workers) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.test_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = np.array([(len(devices) - 1.) 
/ 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations or + per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or + per_device), sess) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1 or + per_device), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or + per_device_2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + # Collective ops doesn't support strategy with one device. 
+ def testReductionLocal(self, num_gpus=2): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_reduction, self._cluster_spec, num_gpus, local_model=True) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 2bb088e704..24cb08fb48 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import threading from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads +# threading.Lock() cannot be pickled and therefore cannot be a field of +# CollectiveKeys. +_lock = threading.Lock() + + +# TODO(yuefengz): use random key starts to avoid reusing keys? +class CollectiveKeys(object): + """Class that manages collective keys. + + We need to manage three different keys for collective: + + *Group key*: an integer key to identify the set of cooperative devices. + Collective ops work under the same set of devices must using the same group + key. + + *Instance key*: an integer key to identify the set of same counterpart of + tensors on different devices in a device group that need to be all-reduced. + + "Graph key": an integer key that is unique key graph. This is used to support + multiple graphs per client session. It must be non-zero and set in the + `config` argument of each call to `session.run`. + """ + + def __init__(self, + group_key_start=1, + instance_key_start=100, + instance_key_with_id_start=10000): + """Initializes the object. + + Args: + group_key_start: the starting integer of group key. + instance_key_start: the starting integer of instance key. + instance_key_with_id_start: the starting integer of instance key that is + recorded with an id. + """ + self._group_key = group_key_start + self._group_key_table = dict() + + # For instance keys with ids + self._instance_key_id_to_key_table = dict() + self._instance_key_with_id_counter = instance_key_with_id_start + + # For instance keys without ids + self._instance_key_start = instance_key_start + + self._thread_local = threading.local() + + def _get_thread_local_object(self): + # We make instance key without key ids thread local so that it will work + # with MirroredStrategy and distribute coordinator. + if not hasattr(self._thread_local, 'instance_key'): + self._thread_local.instance_key = self._instance_key_start + return self._thread_local + + def get_group_key(self, devices): + """Returns a group key for the set of devices. + + Args: + devices: list of strings naming devices in a collective group. + + Returns: + int key uniquely identifying the set of device names. + """ + parsed = [pydev.DeviceSpec.from_string(d) for d in devices] + # In the between-graph replicated training, different workers need to get + # the same device key. 
So we remove the task_type and task_id from the + # devices. + # TODO(yuefengz): in the in-graph replicated training, we need to include + # task_type and task_id. + names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) + key_id = ','.join(names) + with _lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] + + def get_instance_key(self, key_id=None): + """Returns a new instance key for use in defining a collective op. + + Args: + key_id: optional string. If set, key will be recorded and the same key + will be returned when the same key_id is provided. If not, an increasing + instance key will be returned. + """ + if key_id: + with _lock: + if key_id not in self._instance_key_id_to_key_table: + self._instance_key_with_id_counter += 1 + self._instance_key_id_to_key_table[key_id] = ( + self._instance_key_with_id_counter) + return self._instance_key_id_to_key_table[key_id] + else: + v = self._get_thread_local_object().instance_key + self._get_thread_local_object().instance_key += 1 + return v + + +def build_collective_reduce(input_tensors, + num_workers, + collective_keys, + reduction_op='Add', + unary_op='Id'): + """Build a subgraph that does one full all-reduce, using the collective Op. + + Args: + input_tensors: tensors within a single worker graph that are to be reduced + together; must be one per device. + num_workers: total number of workers with identical independent graphs that + will be doing this same reduction. The reduction will actually include + the corresponding tensors at all these workers. + collective_keys: a CollectiveKeys object. + reduction_op: string naming the reduction op. + unary_op: string naming the unary final op. + + Returns: + An array of final tensors, one per device, computed by the full reduction. + + Raises: + ValueError: There must be at least two tensors over all the workers. 
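+
+  Example (a sketch; `t0` and `t1` stand for tensors the caller has already
+  placed on two local devices of a single worker):
+
+    reduced_tensors = build_collective_reduce(
+        [t0, t1], num_workers=1, collective_keys=CollectiveKeys(),
+        reduction_op='Add', unary_op='Id')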
+ """ + group_size = len(input_tensors) * num_workers + if group_size < 2: + raise ValueError('num_workers * len(input_tensors) must be 2 or greater') + devices = [t.device for t in input_tensors] + num_devices = len(devices) + group_key = collective_keys.get_group_key(devices) + instance_key = collective_keys.get_instance_key() + out_tensors = [] + subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec + for d in range(num_devices): + with ops.device(devices[d]): + reduce_op = collective_ops.all_reduce( + input_tensors[d], group_size, group_key, instance_key, reduction_op, + unary_op, subdiv_offsets) + out_tensors.append(reduce_op) + return out_tensors + + def sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg, @@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars, else: raise ValueError('unsupported all_reduce alg: ', alg) - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index eb2d102012..0c26ae8dbc 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -209,6 +209,75 @@ def _reduce_non_distributed_value(distribution, aggregation, value, return values.Mirrored(value_updates) +def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Get synchronization value + synchronization = kwargs.get("synchronization", + variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True + kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. 
+ kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], aggregation) + else: + result = values.MirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices on a single machine. @@ -243,54 +312,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - # Get synchronization value - synchronization = kwargs.get( - "synchronization", variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are tower local. - is_tower_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. - is_tower_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in [ - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN - ]: - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. 
- with tape.stop_recording(): + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring index = {} for i, d in enumerate(devices): with ops.device(d): @@ -314,27 +339,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v + return index - if is_tower_local: - result = values.TowerLocalVariable(index, index[devices[0]], - aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - return result + return _create_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) def distribute_dataset(self, dataset_fn): return values.PerDeviceDataset( diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 2063e57178..249de01f08 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -38,6 +38,12 @@ def create_in_process_cluster(num_workers, num_ps): worker_config = config_pb2.ConfigProto() worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + ps_config = config_pb2.ConfigProto() ps_config.device_count['GPU'] = 0 @@ -86,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase): graph: Optional graph to use during the returned session. config: An optional config_pb2.ConfigProto to use to configure the session. + target: the target of session to connect to. Yields: A Session object that should be used as a context manager to surround -- cgit v1.2.3
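For a sense of how the new strategy is driven, here is a minimal between-graph
sketch modeled on the tests in this change (the worker addresses are
placeholders, each worker runs the same program with its own task_id, and the
session/collective_group_leader wiring shown in multi_worker_test_base.py is
omitted):

    from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
    from tensorflow.python.ops import variable_scope

    distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        num_gpus_per_worker=1,
        cluster_spec={"worker": ["host1:2222", "host2:2222"]},
        task_type="worker",
        task_id=0)

    with distribution.scope():

      def model_fn():
        # Under the strategy scope this creates a MirroredVariable whose
        # initial value is broadcast from the chief's first device using
        # collective broadcast_send/broadcast_recv.
        return variable_scope.get_variable("x", shape=(2, 3))

      x = distribution.call_for_each_tower(model_fn)
      # All-reduce across all local devices and all workers, placing the
      # result on this worker's CPU.
      reduced = distribution.reduce(
          variable_scope.VariableAggregation.MEAN, x, destinations="/cpu:0")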