aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Yuefeng Zhou <yuefengz@google.com> 2018-08-02 21:17:31 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-02 21:24:17 -0700
commit: 493d7588172bcf476309b3954db342839ca37872 (patch)
tree: c917e33910990fb616f070ebf35a02473fa6d85c
parent: 7a78d060c0a50ac77df6bb3c5b2bf6d9a5be1496 (diff)
Add the CollectiveAllReduceStrategy.
PiperOrigin-RevId: 207215423
-rw-r--r-- tensorflow/contrib/distribute/python/BUILD 56
-rw-r--r-- tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py 205
-rw-r--r-- tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py 217
-rw-r--r-- tensorflow/contrib/distribute/python/cross_tower_ops.py 113
-rw-r--r-- tensorflow/contrib/distribute/python/cross_tower_ops_test.py 172
-rw-r--r-- tensorflow/contrib/distribute/python/cross_tower_utils.py 151
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy.py 138
-rw-r--r-- tensorflow/contrib/distribute/python/multi_worker_test_base.py 7
8 files changed, 984 insertions, 75 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index f6cc1dcc02..ecfe044b4f 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -134,6 +134,24 @@ py_library(
)
py_library(
+ name = "collective_all_reduce_strategy",
+ srcs = ["collective_all_reduce_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":cross_tower_ops",
+ ":cross_tower_utils",
+ ":mirrored_strategy",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:collective_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
name = "strategy_test_lib",
testonly = 1,
srcs = ["strategy_test_lib.py"],
@@ -327,6 +345,37 @@ py_library(
],
)
+py_test(
+ name = "collective_all_reduce_strategy_test",
+ srcs = ["collective_all_reduce_strategy_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":collective_all_reduce_strategy",
+ ":combinations",
+ ":cross_tower_utils",
+ ":multi_worker_test_base",
+ ":strategy_test_lib",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:run_config",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
py_library(
name = "minimize_loss_test_lib",
testonly = 1,
@@ -497,8 +546,11 @@ py_library(
"//tensorflow/contrib/all_reduce:all_reduce_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:collective_ops",
+ "//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
],
)
@@ -533,7 +585,9 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
@@ -541,6 +595,7 @@ py_library(
cuda_py_test(
name = "cross_tower_ops_test",
+ size = "large",
srcs = ["cross_tower_ops_test.py"],
additional_deps = [
":combinations",
@@ -555,7 +610,6 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
- shard_count = 15,
tags = [
"multi_and_single_gpu",
"no_pip",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
new file mode 100644
index 0000000000..9afcaecf78
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -0,0 +1,205 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.training import server_lib
+
+
+# TODO(yuefengz): move this function to a common util file.
+def _normalize_cluster_spec(cluster_spec):
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+# TODO(yuefengz): shard the dataset.
+# TODO(yuefengz): support in-graph replication.
+# TODO(yuefengz): it only works with a cluster without a chief node, maybe
+# support chief node?
+class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
+ """Distribution strategy that uses collective ops for all-reduce.
+
+ It is similar to the MirroredStrategy but it uses collective ops for
+ reduction. It currently only works for between-graph replication and its
+ reduction will reduce across all workers.
+ """
+
+ def __init__(self,
+ num_gpus_per_worker=0,
+ cluster_spec=None,
+ task_type="worker",
+ task_id=0):
+ """Initializes the object.
+
+ Args:
+ num_gpus_per_worker: number of local GPUs or GPUs per worker.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ self._num_gpus_per_worker = num_gpus_per_worker
+ self._initialize(cluster_spec, task_type, task_id)
+
+ def _initialize(self, cluster_spec, task_type, task_id):
+ if task_type not in ["chief", "worker"]:
+ raise ValueError(
+ "Unrecognized task_type: %r, valid task types are: \"chief\", "
+ "\"worker\"." % task_type)
+ if cluster_spec:
+ self._cluster_spec = _normalize_cluster_spec(cluster_spec)
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
+ if "chief" in self._cluster_spec.as_dict():
+ num_workers += 1
+ if not num_workers:
+ raise ValueError("`task_type` shoud be in `cluster_spec`.")
+
+ # TODO(yuefengz): create a utility to infer chief.
+ if "chief" in self._cluster_spec.as_dict() and task_type == "chief":
+ assert task_id == 0
+ self._is_chief = True
+ else:
+ assert task_type == "worker"
+ self._is_chief = task_id == 0
+ else:
+ self._cluster_spec = None
+ self._is_chief = True
+ worker_device = ""
+ num_workers = 1
+ self._num_workers = num_workers
+
+ if self._num_gpus_per_worker:
+ local_devices = [
+ "%s/device:GPU:%d" % (worker_device, i)
+ for i in range(self._num_gpus_per_worker)
+ ]
+ else:
+ local_devices = [worker_device]
+
+ self._collective_keys = cross_tower_utils.CollectiveKeys()
+ super(CollectiveAllReduceStrategy, self).__init__(
+ devices=local_devices,
+ cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
+ num_workers=num_workers,
+ num_gpus_per_worker=self._num_gpus_per_worker,
+ collective_keys=self._collective_keys))
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ if cluster_spec:
+ self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id)
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+ group_size = len(devices) * self._num_workers
+ group_key = self._collective_keys.get_group_key(self._devices)
+
+ def _real_mirrored_creator(devices, *args, **kwargs):
+ """Creates one MirroredVariable on the current worker."""
+ index = {}
+ collective_instance_key = self._collective_keys.get_instance_key(
+ key_id=kwargs["name"])
+ if "initial_value" not in kwargs:
+ raise ValueError("Initial value must be specified.")
+ initial_value = kwargs["initial_value"]
+ if callable(initial_value):
+ initial_value_fn = initial_value
+ else:
+ initial_value_fn = lambda: initial_value
+
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+
+ # The initial value fn makes sure all variables are initialized to
+ # the same values. The first device of the chief worker will send its
+ # variable values to other devices and other workers.
+ def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
+ with ops.device(device):
+ initial_value = initial_value_fn()
+ assert not callable(initial_value)
+ initial_value = ops.convert_to_tensor(initial_value)
+
+ if self._is_chief and index == 0:
+ bcast_send = collective_ops.broadcast_send(
+ initial_value, initial_value.shape, initial_value.dtype,
+ group_size, group_key, collective_instance_key)
+ with ops.control_dependencies([bcast_send]):
+ return array_ops.identity(initial_value)
+ else:
+ return collective_ops.broadcast_recv(
+ initial_value.shape, initial_value.dtype, group_size,
+ group_key, collective_instance_key)
+
+ kwargs["initial_value"] = _overridden_initial_value_fn
+
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+
+ assert not isinstance(v, values.DistributedVariable)
+ index[d] = v
+ return index
+
+ # pylint: disable=protected-access
+ return mirrored_strategy._create_mirrored_variable(
+ devices, _real_mirrored_creator, *args, **kwargs)
+
+ def configure(self, session_config=None):
+ # Use TF_CONFIG to get the cluster spec and the current job.
+ if not self._cluster_spec:
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", "worker")
+ task_id = int(task_env.get("index", "0"))
+ else:
+ task_type = "worker"
+ task_id = 0
+
+ if cluster_spec:
+ self._initialize(cluster_spec, task_type, task_id)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
new file mode 100644
index 0000000000..b5e54e3b7d
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -0,0 +1,217 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CollectiveAllReduceStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DistributedCollectiveAllReduceStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+
+ collective_key_base = 0
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+ cls._cluster_spec = {
+ run_config.TaskType.WORKER: [
+ 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
+ ]
+ }
+
+ def setUp(self):
+ self._run_options = config_pb2.RunOptions()
+ self._run_options.experimental.collective_graph_key = 6
+
+ self._sess_config = config_pb2.ConfigProto()
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
+ # We use a different key_base for each test so that collective keys won't be
+ # reused.
+ # TODO(yuefengz, tucker): enable it to reuse collective keys in different
+ # tests.
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
+ super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+
+ def _get_test_object(self, task_type, task_id, num_gpus=0):
+ distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=num_gpus,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ collective_keys = cross_tower_utils.CollectiveKeys(
+ group_key_start=10 * num_gpus +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ instance_key_start=num_gpus * 100 +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ instance_key_with_id_start=num_gpus * 10000 +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ distribution._collective_keys = collective_keys
+ distribution._cross_tower_ops._collective_keys = collective_keys
+ return distribution, self._workers[task_id].target
+
+ def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess, \
+ d.scope():
+ l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker)
+
+ def loss_fn(x):
+ y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
+ return y * y
+
+ # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
+ # multiple graphs (b/111216820).
+ def grad_fn(x):
+ loss = loss_fn(x)
+ var_list = (
+ variables.trainable_variables() + ops.get_collection(
+ ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ grads = gradients.gradients(loss, var_list)
+ ret = list(zip(grads, var_list))
+ return ret
+
+ def update(v, g):
+ return v.assign_sub(0.05 * g, use_locking=True)
+
+ one = d.broadcast(constant_op.constant([[1.]]))
+
+ def step():
+ """Perform one optimization step."""
+ # Run forward & backward to get gradients, variables list.
+ g_v = d.call_for_each_tower(grad_fn, one)
+ # Update the variables using the gradients and the update() function.
+ before_list = []
+ after_list = []
+ for g, v in g_v:
+ fetched = d.read_var(v)
+ before_list.append(fetched)
+ with ops.control_dependencies([fetched]):
+ # TODO(yuefengz): support non-Mirrored variable as destinations.
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
+ with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ after_list.append(d.read_var(v))
+ return before_list, after_list
+
+ before_out, after_out = step()
+
+ if context.num_gpus() < d._num_gpus_per_worker:
+ return True
+
+ sess.run(
+ variables.global_variables_initializer(), options=self._run_options)
+
+ for i in range(10):
+ b, a = sess.run((before_out, after_out), options=self._run_options)
+ if i == 0:
+ before, = b
+ after, = a
+
+ error_before = abs(before - 1)
+ error_after = abs(after - 1)
+ # Error should go down
+ self.assertLess(error_after, error_before)
+ return error_after < error_before
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def _test_variable_initialization(self, task_type, task_id, num_gpus):
+ distribution, master_target = self._get_test_object(task_type, task_id,
+ num_gpus)
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess, \
+ distribution.scope():
+
+ def model_fn():
+ x = variable_scope.get_variable(
+ 'x',
+ shape=(2, 3),
+ initializer=init_ops.random_uniform_initializer(
+ 1.0, 10.0, dtype=dtypes.float32))
+ return array_ops.identity(x)
+
+ x = distribution.call_for_each_tower(model_fn)
+ reduced_x = distribution.unwrap(
+ distribution.reduce(
+ variable_scope.VariableAggregation.MEAN, x,
+ destinations='/cpu:0'))[0]
+
+ sess.run(
+ variables.global_variables_initializer(), options=self._run_options)
+ x_value, reduced_x_value = sess.run(
+ [x, reduced_x], options=self._run_options)
+ self.assertTrue(np.array_equal(x_value, reduced_x_value))
+ return np.array_equal(x_value, reduced_x_value)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
+ parameterized.TestCase):
+
+ def testMinimizeLossGraph(self, num_gpus=2):
+ # Collective ops don't support a strategy with one device.
+ if context.num_gpus() < num_gpus:
+ return
+ distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=num_gpus)
+ self._test_minimize_loss_graph(distribution)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index b6037d2133..94be3fb740 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -267,9 +267,9 @@ def _group_value_by_device(per_device_values):
This grouping is needed to call the all-reduce library because it expects a
list of the following form:
- [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...
- (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...
- (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...
+ [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
+ [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
+ [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
...
]
@@ -290,7 +290,10 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
+def _ungroup_and_make_mirrored(grouped_reduced,
+ destinations,
+ aggregation,
+ num_between_graph_workers=1):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
@@ -303,6 +306,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
destinations: a list of device strings for returned Mirrored objects.
aggregation: Indicates how a variable will be aggregated. Accepted values
are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
+ num_between_graph_workers: number of workers in the between-graph
+ replication.
Returns:
a list of Mirrored objects.
@@ -311,7 +316,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
if aggregation == vs.VariableAggregation.MEAN:
- index[i][destinations[d]] = v / len(destinations)
+ index[i][destinations[d]] = v / (
+ len(destinations) * num_between_graph_workers)
else:
index[i][destinations[d]] = v
return [value_lib.Mirrored(v) for v in index]
@@ -719,6 +725,103 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
aggregation)
+# TODO(yuefengz): support in-graph collective all-reduce.
+class CollectiveAllReduce(CrossTowerOps):
+ """All-reduce cross tower ops using collective ops.
+
+ In the between-graph replicated training, it will still do all-reduces across
+ all workers and then put results on the right destinations.
+ """
+
+ def __init__(self,
+ num_workers=1,
+ num_gpus_per_worker=0,
+ all_reduce_merge_scope=1,
+ collective_keys=None):
+ """Initializes the object.
+
+ Args:
+ num_workers: number of workers in the between-graph replicated training.
+ num_gpus_per_worker: number of GPUs per worker.
+ all_reduce_merge_scope: size of groups into which to partition consecutive
+ gradients grouped under a common 'allreduce' name scope. This is useful
+ for some optimization of collective ops.
+ collective_keys: an optional CollectiveKeys object.
+ """
+ self._num_workers = num_workers
+ self._num_gpus_per_worker = num_gpus_per_worker
+ self._all_reduce_merge_scope = all_reduce_merge_scope
+ self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys(
+ )
+ super(CollectiveAllReduce, self).__init__()
+
+ # TODO(yuefengz, tucker): are index slices supported by collective ops?
+ def _reduce(self, aggregation, per_device_value, destinations):
+ all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
+ if destinations is None or _devices_match(per_device_value, destinations):
+ return all_reduced
+ else:
+ index = {}
+ for d in get_devices_from(destinations):
+ # pylint: disable=protected-access
+ if d in all_reduced._index:
+ index[d] = all_reduced._index[d]
+ else:
+ with ops.device(d):
+ index[d] = array_ops.identity(all_reduced._index.values()[0])
+ return value_lib.Mirrored(index)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
+
+ def _batch_all_reduce(self, aggregation, per_device_values):
+ """All-reduce across all workers in a batch."""
+ if context.executing_eagerly():
+ raise ValueError("Eager mode with collective ops is not supported yet.")
+
+ logging.log_first_n(
+ logging.INFO,
+ "Collective All-reduce invoked with batches size = %d, "
+ "num_workers = %d" % (len(per_device_values), self._num_workers), 10)
+
+ grouped_by_tower = _group_value_by_device(per_device_values)
+
+ grouped_by_var = list(zip(*grouped_by_tower))
+ # grouped_by_var is grouped by variables and takes the following format:
+ # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
+ # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu2, v1_gpu2) ..),
+ # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu2, v2_gpu2) ..),
+ # ...
+ # ]
+ chunked_gv = [
+ grouped_by_var[x:x + self._all_reduce_merge_scope]
+ for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope)
+ ]
+
+ reduced_gv_list = []
+ for chunk in chunked_gv:
+ with ops.name_scope("allreduce"):
+ for grad_and_vars in chunk:
+ scaled_grads = [g for g, _ in grad_and_vars]
+ collective_reduced = cross_tower_utils.build_collective_reduce(
+ scaled_grads, self._num_workers, self._collective_keys, "Add",
+ "Id")
+ result = []
+ for (_, v), g in zip(grad_and_vars, collective_reduced):
+ result.append([g, v])
+ reduced_gv_list.append(result)
+
+ new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
+ return _ungroup_and_make_mirrored(
+ new_tower_grads,
+ per_device_values[0].devices,
+ aggregation,
+ num_between_graph_workers=self._num_workers)
+
+
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 6a780ff60f..21018006c5 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -21,13 +21,17 @@ from __future__ import print_function
import itertools
from absl.testing import parameterized
+import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -376,5 +380,173 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
self._testReductionAndBroadcast(cross_tower_ops, distribution)
+class MultiWorkerCollectiveAllReduceTest(
+ multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+
+ collective_key_base = 10000
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+ cls._cluster_spec = {
+ run_config.TaskType.WORKER: [
+ "fake_worker_0", "fake_worker_1", "fake_worker_2"
+ ]
+ }
+
+ def setUp(self):
+ super(MultiWorkerCollectiveAllReduceTest, self).setUp()
+ # Reusing keys is not well supported, so we have to give a different
+ # collective key base to each test.
+ MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000
+
+ def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False):
+ collective_keys = cross_tower_utils.CollectiveKeys(
+ group_key_start=10 * num_gpus +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base,
+ instance_key_start=num_gpus * 100 +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base,
+ instance_key_with_id_start=num_gpus * 10000 +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base)
+ if local_mode:
+ collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
+ 1, num_gpus, collective_keys=collective_keys)
+ if num_gpus:
+ devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
+ else:
+ devices = ["/device:CPU:0"]
+ return collective_all_reduce_ops, devices, "local"
+ else:
+ collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
+ 3, num_gpus, collective_keys=collective_keys)
+ if num_gpus:
+ devices = [
+ "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i)
+ for i in range(num_gpus)
+ ]
+ else:
+ devices = ["/job:%s/task:%d" % (task_type, task_id)]
+ return collective_all_reduce_ops, devices, self._workers[task_id].target
+
+ def _assert_values_equal(self, left, right, sess):
+ if isinstance(left, list):
+ for l, r in zip(left, right):
+ self._assert_values_equal(l, r, sess)
+ else:
+ self.assertEqual(type(left), type(right))
+ self.assertEqual(set(left.devices), set(right.devices))
+
+ run_options = config_pb2.RunOptions()
+ run_options.experimental.collective_graph_key = 6
+
+ left_values = np.array(
+ sess.run(list(left._index.values()), options=run_options)).flatten()
+ right_values = np.array(right._index.values()).flatten()
+ self.assertEqual(len(left_values), len(right_values))
+ for l, r in zip(left_values, right_values):
+ self.assertEqual(l, r)
+
+ def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False):
+ collective_all_reduce, devices, master_target = self._get_test_objects(
+ task_type, task_id, num_gpus, local_mode=local_mode)
+ if local_mode:
+ num_workers = 1
+ worker_device = None
+ else:
+ num_workers = len(self._workers)
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ with ops.Graph().as_default(), \
+ ops.device(worker_device), \
+ self.test_session(target=master_target) as sess:
+ # Collective ops don't support scalar tensors, so we have to construct
+ # 1-d tensors.
+ values = [constant_op.constant([float(d)]) for d in range(len(devices))]
+ per_device = _make_per_device(values, devices)
+ mean = np.array([(len(devices) - 1.) / 2.])
+
+ values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
+ per_device_2 = _make_per_device(values_2, devices)
+ mean_2 = np.array([mean[0] + 1.])
+
+ destination_mirrored = _fake_mirrored(1., devices)
+ destination_different = _fake_mirrored(1., _cpu_device)
+ destination_str = _cpu_device
+ destination_list = devices
+
+ all_destinations = [
+ None, destination_mirrored, destination_different, destination_str,
+ destination_list
+ ]
+
+ # test reduce()
+ for destinations in all_destinations:
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device,
+ destinations=destinations),
+ _fake_mirrored(mean, destinations or per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device_2,
+ destinations=destinations),
+ _fake_mirrored(mean_2, destinations or per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.SUM,
+ per_device,
+ destinations=destinations),
+ _fake_mirrored(mean * len(devices) * num_workers, destinations or
+ per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.SUM,
+ per_device_2,
+ destinations=destinations),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
+ per_device), sess)
+
+ # test batch_reduce()
+ for d1, d2 in itertools.product(all_destinations, all_destinations):
+ self._assert_values_equal(
+ collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN,
+ [(per_device, d1),
+ (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean, d1 or per_device),
+ _fake_mirrored(mean_2, d2 or per_device_2)
+ ], sess)
+ self._assert_values_equal(
+ collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
+ [(per_device, d1),
+ (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean * len(devices) * num_workers, d1 or
+ per_device),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
+ per_device_2)
+ ], sess)
+
+ return True
+
+ @combinations.generate(
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ def testReductionDistributed(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
+ num_gpus)
+
+ # Collective ops don't support a strategy with one device.
+ def testReductionLocal(self, num_gpus=2):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_reduction, self._cluster_spec, num_gpus, local_mode=True)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 2bb088e704..24cb08fb48 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -19,13 +19,16 @@ from __future__ import division
from __future__ import print_function
import collections as pycoll
+import threading
from tensorflow.contrib import nccl
from tensorflow.contrib.all_reduce.python import all_reduce
from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
@@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads):
return small_grads, large_grads
+# threading.Lock() cannot be pickled and therefore cannot be a field of
+# CollectiveKeys.
+_lock = threading.Lock()
+
+
+# TODO(yuefengz): use random key starts to avoid reusing keys?
+class CollectiveKeys(object):
+ """Class that manages collective keys.
+
+ We need to manage three different keys for collective:
+
+ *Group key*: an integer key to identify the set of cooperative devices.
+ Collective ops that work on the same set of devices must use the same group
+ key.
+
+ *Instance key*: an integer key to identify the set of corresponding tensors
+ on different devices in a device group that need to be all-reduced.
+
+ *Graph key*: an integer key that is unique per graph. This is used to support
+ multiple graphs per client session. It must be non-zero and set in the
+ `config` argument of each call to `session.run`.
+ """
+
+ def __init__(self,
+ group_key_start=1,
+ instance_key_start=100,
+ instance_key_with_id_start=10000):
+ """Initializes the object.
+
+ Args:
+ group_key_start: the starting integer of group key.
+ instance_key_start: the starting integer of instance key.
+ instance_key_with_id_start: the starting integer of instance key that is
+ recorded with an id.
+ """
+ self._group_key = group_key_start
+ self._group_key_table = dict()
+
+ # For instance keys with ids
+ self._instance_key_id_to_key_table = dict()
+ self._instance_key_with_id_counter = instance_key_with_id_start
+
+ # For instance keys without ids
+ self._instance_key_start = instance_key_start
+
+ self._thread_local = threading.local()
+
+ def _get_thread_local_object(self):
+ # We make instance keys without key ids thread-local so that they work
+ # with MirroredStrategy and distribute coordinator.
+ if not hasattr(self._thread_local, 'instance_key'):
+ self._thread_local.instance_key = self._instance_key_start
+ return self._thread_local
+
+ def get_group_key(self, devices):
+ """Returns a group key for the set of devices.
+
+ Args:
+ devices: list of strings naming devices in a collective group.
+
+ Returns:
+ int key uniquely identifying the set of device names.
+ """
+ parsed = [pydev.DeviceSpec.from_string(d) for d in devices]
+ # In between-graph replicated training, different workers need to get
+ # the same group key, so we remove the task_type and task_id from the
+ # devices.
+ # TODO(yuefengz): in the in-graph replicated training, we need to include
+ # task_type and task_id.
+ names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
+ key_id = ','.join(names)
+ with _lock:
+ if key_id not in self._group_key_table:
+ new_key = self._group_key
+ self._group_key += 1
+ self._group_key_table[key_id] = new_key
+ return self._group_key_table[key_id]
+
+ def get_instance_key(self, key_id=None):
+ """Returns a new instance key for use in defining a collective op.
+
+ Args:
+ key_id: optional string. If set, key will be recorded and the same key
+ will be returned when the same key_id is provided. If not, an increasing
+ instance key will be returned.
+ """
+ if key_id:
+ with _lock:
+ if key_id not in self._instance_key_id_to_key_table:
+ self._instance_key_with_id_counter += 1
+ self._instance_key_id_to_key_table[key_id] = (
+ self._instance_key_with_id_counter)
+ return self._instance_key_id_to_key_table[key_id]
+ else:
+ v = self._get_thread_local_object().instance_key
+ self._get_thread_local_object().instance_key += 1
+ return v
+
+
+def build_collective_reduce(input_tensors,
+ num_workers,
+ collective_keys,
+ reduction_op='Add',
+ unary_op='Id'):
+ """Build a subgraph that does one full all-reduce, using the collective Op.
+
+ Args:
+ input_tensors: tensors within a single worker graph that are to be reduced
+ together; must be one per device.
+ num_workers: total number of workers with identical independent graphs that
+ will be doing this same reduction. The reduction will actually include
+ the corresponding tensors at all these workers.
+ collective_keys: a CollectiveKeys object.
+ reduction_op: string naming the reduction op.
+ unary_op: string naming the unary final op.
+
+ Returns:
+ An array of final tensors, one per device, computed by the full reduction.
+
+ Raises:
+ ValueError: There must be at least two tensors over all the workers.
+ """
+ group_size = len(input_tensors) * num_workers
+ if group_size < 2:
+ raise ValueError('num_workers * len(input_tensors) must be 2 or greater')
+ devices = [t.device for t in input_tensors]
+ num_devices = len(devices)
+ group_key = collective_keys.get_group_key(devices)
+ instance_key = collective_keys.get_instance_key()
+ out_tensors = []
+ subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
+ for d in range(num_devices):
+ with ops.device(devices[d]):
+ reduce_op = collective_ops.all_reduce(
+ input_tensors[d], group_size, group_key, instance_key, reduction_op,
+ unary_op, subdiv_offsets)
+ out_tensors.append(reduce_op)
+ return out_tensors
+
+
def sum_grad_and_var_all_reduce(grad_and_vars,
num_workers,
alg,
@@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars,
else:
raise ValueError('unsupported all_reduce alg: ', alg)
- result = []
- for (_, v), g in zip(grad_and_vars, summed_grads):
- result.append([g, v])
- return result
+ result = []
+ for (_, v), g in zip(grad_and_vars, summed_grads):
+ result.append([g, v])
+ return result
def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index eb2d102012..0c26ae8dbc 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -209,6 +209,75 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
return values.Mirrored(value_updates)
+def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Get synchronization value
+ synchronization = kwargs.get("synchronization",
+ variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
+ kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [
+ variable_scope.VariableAggregation.NONE,
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
+ else:
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+
class MirroredStrategy(distribute_lib.DistributionStrategy):
"""Mirrors vars to distribute across multiple devices on a single machine.
@@ -243,54 +312,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
- # Figure out what collections this variable should be added to.
- # We'll add the MirroredVariable to those collections instead.
- collections = kwargs.pop("collections", None)
- if collections is None:
- collections = [ops.GraphKeys.GLOBAL_VARIABLES]
- kwargs["collections"] = []
-
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- # Get synchronization value
- synchronization = kwargs.get(
- "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
- if synchronization == variable_scope.VariableSynchronization.NONE:
- raise ValueError("`NONE` variable synchronization mode is not "
- "supported with `Mirrored` distribution strategy. Please"
- " change the `synchronization` for variable: " +
- kwargs["name"])
- elif synchronization == variable_scope.VariableSynchronization.ON_READ:
- # Variables that are to be synced on read are tower local.
- is_tower_local = True
- kwargs["trainable"] = False
- elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
- synchronization == variable_scope.VariableSynchronization.AUTO):
- # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
- is_tower_local = False
- else:
- raise ValueError("Invalid variable synchronization mode: " +
- synchronization + " for variable: " + kwargs["name"])
-
- # Get aggregation value
- aggregation = kwargs.pop("aggregation",
- variable_scope.VariableAggregation.NONE)
- if aggregation not in [
- variable_scope.VariableAggregation.NONE,
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
- raise ValueError("Invalid variable aggregation mode: " + aggregation +
- " for variable: " + kwargs["name"])
-
- # Ignore user-specified caching device, not needed for mirrored variables.
- kwargs.pop("caching_device", None)
-
- # TODO(josh11b,apassos): It would be better if variable initialization
- # was never recorded on the tape instead of having to do this manually
- # here.
- with tape.stop_recording():
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
index = {}
for i, d in enumerate(devices):
with ops.device(d):
@@ -314,27 +339,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
+ return index
- if is_tower_local:
- result = values.TowerLocalVariable(index, index[devices[0]],
- aggregation)
- else:
- result = values.MirroredVariable(index, index[devices[0]], aggregation)
-
- if not context.executing_eagerly():
- g = ops.get_default_graph()
- # If "trainable" is True, next_creator() will add the member variables
- # to the TRAINABLE_VARIABLES collection, so we manually remove
- # them and replace with the MirroredVariable. We can't set
- # "trainable" to False for next_creator() since that causes functions
- # like implicit_gradients to skip those variables.
- if kwargs.get("trainable", True):
- collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
- l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
- for v in index.values():
- l.remove(v)
- g.add_to_collections(collections, result)
- return result
+ return _create_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
def distribute_dataset(self, dataset_fn):
return values.PerDeviceDataset(
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 2063e57178..249de01f08 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -38,6 +38,12 @@ def create_in_process_cluster(num_workers, num_ps):
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+ # Enable collective ops; this has no impact on non-collective ops.
+ # TODO(yuefengz, tucker): remove this after we move the initialization of
+ # collective mgr to the session level.
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
@@ -86,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase):
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
+ target: the target of session to connect to.
Yields:
A Session object that should be used as a context manager to surround