author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-21 12:27:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-09-21 12:33:33 -0700
commit    21f37286cb987fbac1b1f1ed414083df44188638 (patch)
tree      b5617d15ed8a16a6ad0dc46549c30263a644d6e4 /tensorflow/contrib/all_reduce
parent    abf4aa037f21231594c44fc08fb50607de0288b7 (diff)
Add new contrib/all_reduce directory with several implementations
of all-reduce as TensorFlow subgraphs of multiple Ops.
PiperOrigin-RevId: 169582099
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r--  tensorflow/contrib/all_reduce/BUILD                      |  54
-rw-r--r--  tensorflow/contrib/all_reduce/python/all_reduce.py       | 858
-rw-r--r--  tensorflow/contrib/all_reduce/python/all_reduce_test.py  | 229
3 files changed, 1141 insertions, 0 deletions
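
The new module exposes graph builders such as build_ring_all_reduce, build_recursive_hd_all_reduce and build_shuffle_all_reduce. A minimal usage sketch, assuming TF 1.x graph mode and reusing the all-on-local-CPU placement that all_reduce_test.py uses (a real deployment would pin one input per GPU):

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce as ar

dev = "/replica:0/task:0/device:CPU:0"
inputs = []
for _ in range(2):
  with tf.device(dev):
    inputs.append(tf.constant([1.0, 2.0, 3.0, 4.0]))

# One worker, one subchunk, ring order [0, 1] over the two devices.
outputs = ar.build_ring_all_reduce(inputs, 1, 1, [0, 1], tf.add)
with tf.Session() as sess:
  print(sess.run(outputs))  # both outputs equal [2.0, 4.0, 6.0, 8.0]
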
diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD
new file mode 100644
index 0000000000..744ae4c1f4
--- /dev/null
+++ b/tensorflow/contrib/all_reduce/BUILD
@@ -0,0 +1,54 @@
+# Description:
+# All-reduce implementations.
+# APIs are subject to change. Eventually to be replaced by equivalent
+# functionality within TensorFlow core.
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+py_library(
+ name = "all_reduce",
+ srcs = [
+ "python/all_reduce.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/nccl:nccl_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_py_test(
+ name = "all_reduce_test",
+ srcs = ["python/all_reduce_test.py"],
+ additional_deps = [
+ ":all_reduce",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ "g3doc/sitemap.md",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
new file mode 100644
index 0000000000..8e7f1791b8
--- /dev/null
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -0,0 +1,858 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import re
+
+from tensorflow.contrib import nccl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _flatten_tensors(tensors):
+ """Check tensors for isomorphism and flatten.
+
+ Args:
+ tensors: list of T @{tf.Tensor} which must all have the same shape.
+
+ Returns:
+ tensors: a list of T @{tf.Tensor} which are flattened (1D) views of tensors
+ shape: the original shape of each element of input tensors
+
+ Raises:
+ ValueError: tensors are empty or non-isomorphic.
+ """
+ if not tensors:
+ raise ValueError("tensors cannot be empty")
+ shape = tensors[0].shape
+ for tensor in tensors:
+ shape = shape.merge_with(tensor.shape)
+ if shape.ndims is None:
+ raise ValueError("At least one of the tensors in 'tensors' must have "
+ "statically known rank.")
+ if len(shape) > 1:
+ reshaped = []
+ for t in tensors:
+ with ops.colocate_with(t):
+ reshaped.append(array_ops.reshape(t, [-1]))
+ tensors = reshaped
+ return tensors, shape
+
+
+def _reshape_tensors(tensors, shape):
+ """Reshape tensors flattened by _flatten_tensors.
+
+ Args:
+ tensors: list of T @{tf.Tensor} of identical length 1D tensors.
+ shape: list of integers describing the desired shape. Product of
+ the elements must equal the length of each tensor.
+
+ Returns:
+ list of T @{tf.Tensor} which are the reshaped inputs.
+ """
+ reshaped = []
+ for t in tensors:
+ with ops.colocate_with(t):
+ reshaped.append(array_ops.reshape(t, shape))
+ return reshaped
+
+
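
A small round-trip sketch of the two helpers above, assuming the module is imported as ar (as in all_reduce_test.py):

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce as ar

pair = [tf.zeros([4, 4]), tf.zeros([4, 4])]
flat, shape = ar._flatten_tensors(pair)      # two 1D views of 16 elements each
restored = ar._reshape_tensors(flat, shape)  # back to two [4, 4] tensors
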
+def _padded_split(tensor, pieces):
+ """Like split for 1D tensors but pads-out case where len % pieces != 0.
+
+ Args:
+ tensor: T @{tf.Tensor} that must be 1D.
+ pieces: a positive integer specifying the number of pieces into which
+ tensor should be split.
+
+ Returns:
+ list of T @{tf.Tensor} of length pieces, which hold the values of
+ the input tensor, in order. The final tensor may
+ be zero-padded on the end to make its size equal to those of all
+ of the other tensors.
+
+ Raises:
+ ValueError: The input tensor is not 1D.
+ """
+ shape = tensor.shape
+ if 1 != len(shape):
+ raise ValueError("input tensor must be 1D")
+ tensor_len = shape[0].value
+ with ops.colocate_with(tensor):
+ if tensor_len % pieces != 0:
+ # pad to an even length
+ chunk_size = 1 + tensor_len // pieces
+ if pieces > tensor_len:
+ # This is an edge case that should not come up in practice,
+ # i.e. a different reduction algorithm would be better,
+ # but we'll make it work just for completeness.
+ pad_len = pieces - tensor_len
+ extended_whole = array_ops.concat(
+ [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
+ parts = array_ops.split(extended_whole, pieces)
+ return parts, pad_len
+ elif (pieces - 1) * chunk_size >= tensor_len:
+ # Another edge case of limited real interest.
+ pad_len = (pieces * chunk_size) % tensor_len
+ extended_whole = array_ops.concat(
+ [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
+ parts = array_ops.split(extended_whole, pieces)
+ return parts, pad_len
+ else:
+ last_chunk_size = tensor_len - (pieces - 1) * chunk_size
+ pad_len = chunk_size - last_chunk_size
+ piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
+ parts = array_ops.split(tensor, piece_lens)
+ parts[-1] = array_ops.concat(
+ [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
+ return parts, pad_len
+ else:
+ return array_ops.split(tensor, pieces), 0
+
+
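
For concreteness, a sketch of the padding behavior under the same ar import assumption: a length-10 tensor split into 4 pieces gets chunk_size = 1 + 10 // 4 = 3, so the last piece is zero-padded from length 1 up to 3 and pad_len = 2 is returned.

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce as ar

t = tf.constant([float(i) for i in range(10)])
parts, pad_len = ar._padded_split(t, 4)
# len(parts) == 4, each part has static shape [3], pad_len == 2
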
+def _strip_padding(tensors, pad_len):
+ """Strip the suffix padding added by _padded_split.
+
+ Args:
+ tensors: list of T @{tf.Tensor} of identical length 1D tensors.
+ pad_len: number of elements to be stripped from the end of each tensor.
+
+ Returns:
+ list of T @{tf.Tensor} which are the stripped inputs.
+
+ Raises:
+ ValueError: tensors must be a non-empty list of 1D tensors, and
+ each must be longer than pad_len.
+ """
+ if not tensors:
+ raise ValueError("tensors cannot be empty")
+ shape = tensors[0].shape
+ if len(shape) > 1:
+ raise ValueError("tensors must be 1D")
+ prefix_len = int(shape[0] - pad_len)
+ if prefix_len < 0:
+ raise ValueError("pad_len longer than tensor")
+ stripped = []
+ for t in tensors:
+ with ops.colocate_with(t):
+ stripped.append(array_ops.slice(t, [0], [prefix_len]))
+ return stripped
+
+
+def _ragged_split(tensor, pieces):
+ """Like split for 1D tensors but allows case where len % pieces != 0.
+
+ Args:
+ tensor: T @{tf.Tensor} that must be 1D.
+ pieces: a positive integer specifying the number of pieces into which
+ tensor should be split.
+
+ Returns:
+ list of T @{tf.Tensor} of length pieces, which hold the values of
+ the input tensor, in order. The final tensor may be shorter
+ than the others, which will all be of equal length.
+
+ Raises:
+ ValueError: input tensor must be 1D.
+ """
+ shape = tensor.shape
+ if 1 != len(shape):
+ raise ValueError("input tensor must be 1D")
+ tensor_len = shape[0].value
+ chunk_size = tensor_len // pieces
+ with ops.colocate_with(tensor):
+ if tensor_len != (pieces * chunk_size):
+ # last piece will be short
+ assert pieces > 1
+ last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
+ assert last_chunk_size > 0
+ piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
+ return array_ops.split(tensor, piece_lens)
+ else:
+ return array_ops.split(tensor, pieces)
+
+
+def _ring_permutations(num_workers, num_subchunks, gpu_perm):
+ """"Generate an array of device index arrays, one for for each subchunk.
+
+ In the basic ring reduction algorithm there are size(T)/num_devices
+ data chunks and each device process one chunk per tick, i.e. sending
+ one chunk and receiving one chunk. The idea of subchunking is that
+ each device processes num_subchunks smaller data regions per tick,
+ and the ring rank permutation is different for each subchunk index
+ so that a device is potentially sending to and receiving from
+ num_subchunks different other devices at each tick. Where multiple
+ independent data channels exist between devices, this strategy
+ supplies a method of using them in parallel.
+
+ Args:
+ num_workers: number of worker tasks
+ num_subchunks: number of subchunks into which to divide each per-GPU chunk.
+ gpu_perm: an array of integers in [0, num_gpus-1] giving the default
+ ring order of GPUs at each worker. Other permutations will be generated
+ by rotating this array and splicing together per-worker instances.
+
+ Raises:
+ ValueError: the number of subchunks may not exceed the number of GPUs.
+
+ Returns:
+ pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
+ preceding device in the permutation for that subchunk. The
+ device index of GPU i at worker j is i + (j * num_gpus).
+ rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
+ local rank of device d in the permutation for that subchunk.
+ """
+ num_gpus = len(gpu_perm)
+ devices = num_workers * num_gpus
+ if devices == 0:
+ return [], []
+ if num_subchunks > num_gpus:
+ raise ValueError(
+ "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
+ rotation_interval = max(1, int(num_gpus / num_subchunks))
+ perms_by_s = []
+ for s in range(0, num_subchunks):
+ full_order = []
+ offset = s * rotation_interval
+ for w in range(0, num_workers):
+ default_order = [(w * num_gpus) + i for i in gpu_perm]
+ dev_order = default_order[offset:] + default_order[:offset]
+ full_order += dev_order
+ perms_by_s.append(full_order)
+ pred_by_s_d = [[-1 for d in range(0, devices)]
+ for s in range(0, num_subchunks)]
+ rank_by_s_d = [[-1 for d in range(0, devices)]
+ for s in range(0, num_subchunks)]
+ for s in range(0, num_subchunks):
+ for d in range(0, devices):
+ for t in range(0, devices):
+ if d == perms_by_s[s][t]:
+ rank_by_s_d[s][d] = t
+ pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
+ break
+ return (pred_by_s_d, rank_by_s_d)
+
+
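
A worked example of the permutation tables for one worker, four GPUs and two subchunks; the same case is asserted in all_reduce_test.py below:

from tensorflow.contrib.all_reduce.python import all_reduce as ar

pred_by_s_d, rank_by_s_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
# The ring for subchunk 1 is rotated by num_gpus // num_subchunks = 2, so the
# local ranks shift while each device keeps the same predecessor (same cycle):
# pred_by_s_d == [[3, 0, 1, 2], [3, 0, 1, 2]]
# rank_by_s_d == [[0, 1, 2, 3], [2, 3, 0, 1]]
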
+def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
+ gpu_perm, red_op, un_op=None):
+ """Construct a subgraph performing a ring-style all-reduce of input_tensors.
+
+ Args:
+ input_tensors: a list of T @{tf.Tensor} objects, which must all
+ have the same shape and type.
+ num_workers: number of worker tasks spanned by input_tensors.
+ num_subchunks: number of subchunks each device should process in one tick.
+ gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
+ each worker. All workers must have the same number of
+ GPUs with the same rank ordering. If NVLINK is available, this should
+ be a ring order supported by NVLINK edges.
+ red_op: a binary operator for elementwise reduction.
+ un_op: an optional unary operator to apply to fully reduced values.
+
+ Raises:
+ ValueError: empty input_tensors or they don't all have same
+ size.
+
+ Returns:
+ a list of T @{tf.Tensor} identical sum-reductions of input_tensors.
+ """
+ if len(input_tensors) < 2:
+ raise ValueError("input_tensors must be length 2 or longer")
+ input_tensors, shape = _flatten_tensors(input_tensors)
+ devices = [t.device for t in input_tensors]
+ (pred_by_s_d, rank_by_s_d) = _ring_permutations(
+ num_workers, num_subchunks, gpu_perm)
+ chunks_by_dev, pad_len = _build_ring_gather(
+ input_tensors, devices,
+ num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
+ if un_op:
+ chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
+ output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
+ chunks_by_dev)
+ if pad_len > 0:
+ output_tensors = _strip_padding(output_tensors, pad_len)
+ if len(shape) > 1:
+ output_tensors = _reshape_tensors(output_tensors, shape)
+ return output_tensors
+
+
+def _build_ring_gather(input_tensors, devices, num_subchunks,
+ pred_by_s_d, rank_by_s_d, red_op):
+ """Construct a subgraph for the first (reduction) pass of ring all-reduce.
+
+ Args:
+ input_tensors: a list of T @{tf.Tensor} 1D input tensors of same
+ shape and type.
+ devices: array of device name strings
+ num_subchunks: number of subchunks each device should process in one tick.
+ pred_by_s_d: as produced by _ring_permutations
+ rank_by_s_d: as produced by _ring_permutations
+ red_op: a binary operator for elementwise reduction
+
+ Raises:
+ ValueError: tensors must all be one dimensional.
+
+ Returns:
+ list of list of T @{tf.Tensor} of (partially) reduced values where
+ exactly num_subchunks chunks at each device are fully reduced.
+ """
+ num_devices = len(input_tensors)
+ if num_devices == 0:
+ return []
+ if num_devices == 1:
+ return input_tensors
+ shape = input_tensors[0].shape
+ if 1 != len(shape):
+ raise ValueError("input tensors must be 1D")
+ num_chunks = num_devices * num_subchunks
+ num_ticks = num_devices - 1
+ # Initialize chunks_by_dev with splits of the input tensors.
+ chunks_by_dev = []
+ split_pad_len = 0
+ for d in range(0, num_devices):
+ with ops.device(devices[d]):
+ splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
+ chunks_by_dev.append(splits)
+ # Reduction phase
+ for tick in range(0, num_ticks):
+ # One new partial reduction for every chunk
+ new_partial_reductions = [None for _ in range(0, num_chunks)]
+ # Compute reductions with respect to last tick's values
+ for d in range(0, num_devices):
+ with ops.device(devices[d]):
+ for s in range(0, num_subchunks):
+ rank = rank_by_s_d[s][d]
+ seg_index = (rank + num_devices - (2 + tick)) % num_devices
+ pred_dev = pred_by_s_d[s][d]
+ chunk_index = (seg_index * num_subchunks) + s
+ new_partial_reductions[chunk_index] = red_op(
+ chunks_by_dev[pred_dev][chunk_index],
+ chunks_by_dev[d][chunk_index])
+ # Update chunks_by_dev with the new values at the end of the tick.
+ for d in range(0, num_devices):
+ for s in range(0, num_subchunks):
+ rank = rank_by_s_d[s][d]
+ seg_index = (rank + num_devices - (2 + tick)) % num_devices
+ chunk_index = (seg_index * num_subchunks) + s
+ chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
+ return chunks_by_dev, split_pad_len
+
+
+def _apply_unary_to_chunks(f, chunks_by_dev):
+ """Apply a unary op to each tensor in chunks_by_dev, on same device.
+
+ Args:
+ f: a unary function over T @{tf.Tensor}.
+ chunks_by_dev: list of lists of T @{tf.Tensor}.
+
+ Returns:
+ new list of lists of T @{tf.Tensor} with the same structure as
+ chunks_by_dev containing the derived tensors.
+ """
+ output = []
+ for x in chunks_by_dev:
+ with ops.colocate_with(x[0]):
+ output.append([f(t) for t in x])
+ return output
+
+
+def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
+ chunks_by_dev):
+ """Construct subgraph for second (scatter) pass of ring all-reduce.
+
+ Args:
+ pred_by_s_d: as produced by _ring_permutations
+ rank_by_s_d: as produced by _ring_permutations
+ chunks_by_dev: list of list of T @{tf.Tensor} indexed by ints
+ (device, chunk)
+
+ Raises:
+ ValueError: chunks_by_dev is not well-formed
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced tensors, one
+ at each device corresponding to the outer dimension of chunks_by_dev.
+ """
+ num_devices = len(chunks_by_dev)
+ num_chunks = len(chunks_by_dev[0])
+ if 0 != num_chunks % num_devices:
+ raise ValueError(
+ "Expect number of chunks per device to be divisible by num_devices")
+ num_subchunks = int(num_chunks / num_devices)
+ num_ticks = num_devices - 1
+ for tick in range(0, num_ticks):
+ passed_values = [None for _ in range(0, num_chunks)]
+ for d in range(0, num_devices):
+ with ops.colocate_with(chunks_by_dev[d][0]):
+ for s in range(0, num_subchunks):
+ rank = rank_by_s_d[s][d]
+ seg_index = (rank + num_devices - (1 + tick)) % num_devices
+ pred_dev = pred_by_s_d[s][d]
+ chunk_index = (seg_index * num_subchunks) + s
+ passed_values[chunk_index] = array_ops.identity(
+ chunks_by_dev[pred_dev][chunk_index])
+ for d in range(0, num_devices):
+ for s in range(0, num_subchunks):
+ rank = rank_by_s_d[s][d]
+ seg_index = (rank + num_devices - (1 + tick)) % num_devices
+ chunk_index = (seg_index * num_subchunks) + s
+ chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
+ # Join chunks at each device.
+ output = []
+ for x in chunks_by_dev:
+ with ops.colocate_with(x[0]):
+ output.append(array_ops.concat(x, 0))
+ return output
+
+
+def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
+ """Construct a subgraph for recursive halving-doubling all-reduce.
+
+ The recursive halving-doubling algorithm is described in
+ http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf
+
+ The concept is to arrange the participating n devices in
+ a linear sequence where devices exchange data pairwise
+ with one other device in each round. During the gather
+ phase there are lg(n) rounds where devices exchange
+ increasingly smaller sub-tensors with another device
+ at increasingly greater distances, until at the top
+ each device has 1/n of the fully reduced values. During the
+ scatter phase each device exchanges its fully reduced
+ sub-tensor (which doubles in length at each round)
+ with one other device at increasingly smaller distances
+ until each device has all of the fully reduced values.
+
+ Note: this preliminary version requires that len(input_tensors) be a
+ power of 2. TODO(tucker): relax this restriction. Also, the
+ number of elements in each tensor must be divisible by 2^h where h
+ is the number of hops in each phase. This will also be relaxed in
+ the future with edge-case specific logic.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} to be elementwise reduced.
+ red_op: a binary elementwise reduction Op.
+ un_op: an optional unary elementwise Op to apply to reduced values.
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced tensors, one
+ at each device of input_tensors.
+
+ Raises:
+ ValueError: num_devices not a power of 2, or tensor len not divisible
+ by 2 the proper number of times.
+ """
+ devices = [t.device for t in input_tensors]
+ input_tensors, shape = _flatten_tensors(input_tensors)
+ reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
+ if un_op:
+ reduced_shards = [un_op(t) for t in reduced_shards]
+ output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
+ if len(shape) > 1:
+ output_tensors = _reshape_tensors(output_tensors, shape)
+ return output_tensors
+
+
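
A usage sketch under the stated restrictions (power-of-two device count, tensor length divisible by 2 per hop), again borrowing the local-CPU placement from the tests:

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce as ar

dev = "/replica:0/task:0/device:CPU:0"
inputs = []
for _ in range(4):  # 4 devices, 2 hops, so lengths must be divisible by 4
  with tf.device(dev):
    inputs.append(tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
outputs = ar.build_recursive_hd_all_reduce(inputs, tf.add)
# each of the 4 outputs is the elementwise sum of the 4 inputs
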
+def _build_recursive_hd_gather(input_tensors, devices, red_op):
+ """Construct the gather phase of recursive halving-doubling all-reduce.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} to be elementwise reduced.
+ devices: a list of strings naming the devices hosting input_tensors,
+ which will also be used to host the (partial) reduction values.
+ red_op: a binary elementwise reduction Op.
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced tensor shards.
+
+ Raises:
+ ValueError: num_devices not a power of 2, or tensor len not divisible
+ by 2 the proper number of times.
+ """
+ num_devices = len(devices)
+ num_hops = int(math.log(num_devices, 2))
+ if num_devices != (2 ** num_hops):
+ raise ValueError("num_devices must be a power of 2")
+ chunks = input_tensors
+ for h in range(0, num_hops):
+ span = 2 ** h
+ group_size = span * 2
+ new_chunks = [[] for _ in devices]
+ for d in range(0, num_devices):
+ if (d % group_size) >= (group_size / 2):
+ # skip right half of a pair
+ continue
+ left_dev = devices[d]
+ right_dev = devices[d + span]
+ left_split = array_ops.split(chunks[d], 2)
+ right_split = array_ops.split(chunks[d+span], 2)
+ with ops.device(left_dev):
+ new_chunks[d] = red_op(left_split[0], right_split[0])
+ with ops.device(right_dev):
+ new_chunks[d + span] = red_op(left_split[1], right_split[1])
+ chunks = new_chunks
+ return chunks
+
+
+def _build_recursive_hd_scatter(input_tensors, devices):
+ """Construct the scatter phase of recursive halving-doublng all-reduce.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} that are fully-reduced shards.
+ devices: a list of strings naming the devices on which the reconstituted
+ full tensors should be placed.
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced tensors.
+ """
+ num_devices = len(devices)
+ num_hops = int(math.log(num_devices, 2))
+ assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
+ chunks = input_tensors
+ for h in reversed(range(0, num_hops)):
+ span = 2 ** h
+ group_size = span * 2
+ new_chunks = [[] for _ in devices]
+ for d in range(0, num_devices):
+ if (d % group_size) >= (group_size / 2):
+ # skip right half of a pair
+ continue
+ left_idx = d
+ right_idx = d + span
+ left_dev = devices[left_idx]
+ right_dev = devices[right_idx]
+ with ops.device(left_dev):
+ new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
+ chunks[right_idx]], 0)
+ with ops.device(right_dev):
+ new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
+ chunks[right_idx]], 0)
+ chunks = new_chunks
+ return chunks
+
+
+def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
+ """Construct a subgraph for shuffle all-reduce.
+
+ Shuffle reduce is essentially the algorithm implemented when using
+ parameter servers. Suppose tensor length is n, there are d devices
+ and g gather shards. Each device sends a n/g length sub-tensor to
+ each gather shard. The gather shards perform a reduction across d
+ fragments, then broadcast the result back to each device. The
+ devices then join the g fully reduced fragments they receive from
+ the shards. The gather shards could perform d-1 pairwise
+ reductions, or one d-way reduction. The first is better where
+ reduction Op time is low compared to transmission time, the second
+ better in the other case.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} values to be reduced.
+ gather_devices: list of names of devices on which reduction shards
+ should be placed.
+ red_op: an n-array elementwise reduction Op
+ un_op: optional elementwise unary Op to be applied to fully-reduced values.
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced tensors.
+ """
+ input_tensors, shape = _flatten_tensors(input_tensors)
+ dst_devices = [t.device for t in input_tensors]
+ reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
+ red_op, un_op)
+ output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
+ if len(shape) > 1:
+ output_tensors = _reshape_tensors(output_tensors, shape)
+ return output_tensors
+
+
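
A sketch of the shuffle variant; note that red_op here is n-ary (the tests pass math_ops.add_n) and a single local gather shard is assumed:

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce as ar

dev = "/replica:0/task:0/device:CPU:0"
inputs = []
for _ in range(3):
  with tf.device(dev):
    inputs.append(tf.constant([1.0, 2.0, 3.0, 4.0]))
outputs = ar.build_shuffle_all_reduce(inputs, [dev], tf.add_n)
# three outputs, each the elementwise sum of the three inputs
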
+def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
+ """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} values to be reduced.
+ gather_devices: list of names of devices on which reduction shards
+ should be placed.
+ red_op: the binary reduction Op
+ un_op: optional elementwise unary Op to be applied to fully-reduced values.
+
+ Returns:
+ list of T @{tf.Tensor} which are the fully reduced shards.
+
+ Raises:
+ ValueError: inputs not well-formed.
+ """
+ num_source_devices = len(input_tensors)
+ num_gather_devices = len(gather_devices)
+ shape = input_tensors[0].shape
+ if len(shape) != 1:
+ raise ValueError("input_tensors must be 1D")
+ shards_by_source = []
+ for d in range(0, num_source_devices):
+ with ops.colocate_with(input_tensors[d]):
+ shards_by_source.append(
+ _ragged_split(input_tensors[d], num_gather_devices))
+ reduced_shards = []
+ for d in range(0, num_gather_devices):
+ with ops.device(gather_devices[d]):
+ values = [s[d] for s in shards_by_source]
+ red_shard = red_op(values)
+ if un_op:
+ red_shard = un_op(red_shard)
+ reduced_shards.append(red_shard)
+ return reduced_shards
+
+
+def _build_shuffle_scatter(reduced_shards, dst_devices):
+ """Build the scatter phase of shuffle all-reduce.
+
+ Args:
+ reduced_shards: list of T @{tf.Tensor} fully reduced shards
+ dst_devices: list of names of devices at which the fully-reduced value
+ should be reconstituted.
+
+ Returns:
+ list of T @{tf.Tensor} scattered tensors.
+ """
+ num_devices = len(dst_devices)
+ out_tensors = []
+ for d in range(0, num_devices):
+ with ops.device(dst_devices[d]):
+ out_tensors.append(array_ops.concat(reduced_shards, 0))
+ return out_tensors
+
+
+def _split_by_task(devices, values):
+ """Partition devices and values by common task.
+
+ Args:
+ devices: list of device name strings
+ values: list of T @{tf.Tensor} of same length as devices.
+
+ Returns:
+ (per_task_devices, per_task_values) where both values are
+ lists of lists with isomorphic structure: the outer list is
+ indexed by task, and the inner list has length of the number
+ of values belonging to that task. per_task_devices contains
+ the specific devices to which the values are local, and
+ per_task_values contains the corresponding values.
+
+ Raises:
+ ValueError: devices must be same length as values.
+ """
+ num_devices = len(devices)
+ if num_devices != len(values):
+ raise ValueError("len(devices) must equal len(values)")
+ pattern = re.compile(r"/task:(\d+)/")
+ per_task_devices = []
+ per_task_values = []
+ for d in range(num_devices):
+ m = pattern.search(devices[d])
+ if m:
+ index = int(m.group(1))
+ while index >= len(per_task_devices):
+ per_task_devices.append([])
+ per_task_values.append([])
+ per_task_devices[index].append(devices[d])
+ per_task_values[index].append(values[d])
+ else:
+ assert False, "failed to parse device %s" % devices[d]
+ return (per_task_devices, per_task_values)
+
+
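
The task index is parsed straight out of the device name; a small illustration with hypothetical device strings:

from tensorflow.contrib.all_reduce.python import all_reduce as ar

devices = ["/job:worker/replica:0/task:0/device:GPU:0",
           "/job:worker/replica:0/task:1/device:GPU:0"]
per_task_devices, per_task_values = ar._split_by_task(devices, ["a", "b"])
# per_task_devices == [[devices[0]], [devices[1]]]
# per_task_values  == [["a"], ["b"]]
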
+def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
+ """Build a subgraph that does one full all-reduce, using NCCL.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} of same-shape and type values to
+ be reduced.
+ red_op: binary elementwise reduction operator. Must be one of
+ {tf.add}
+ un_op: optional unary elementwise Op to apply to fully-reduced values.
+
+ Returns:
+ list of T @{tf.Tensor} of reduced values.
+
+ Raises:
+ ValueError: red_op not supported.
+ """
+ if red_op == math_ops.add:
+ output_tensors = nccl.all_sum(input_tensors)
+ else:
+ raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
+ if un_op:
+ un_op_wrapped = []
+ for t in output_tensors:
+ with ops.colocate_with(t):
+ un_op_wrapped.append(un_op(t))
+ output_tensors = un_op_wrapped
+ return output_tensors
+
+
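
Only math_ops.add is accepted as red_op on this path; a sketch assuming the inputs are already placed on NCCL-capable GPUs (hypothetical device names):

import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.contrib.all_reduce.python import all_reduce as ar

inputs = []
for i in range(2):
  with tf.device("/device:GPU:%d" % i):  # hypothetical local GPUs
    inputs.append(tf.constant([1.0, 2.0, 3.0, 4.0]))
outputs = ar.build_nccl_all_reduce(inputs, math_ops.add)  # dispatches to nccl.all_sum
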
+def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
+ """Construct a subgraph for NCCL hybrid all-reduce.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} of same-shape and type values to
+ be reduced.
+ red_op: binary elementwise reduction operator.
+ upper_level_f: function for reducing one value per worker, across
+ workers.
+
+ Returns:
+ list of T @{tf.Tensor} of reduced values.
+
+ Raises:
+ ValueError: inputs not well-formed.
+ """
+ input_tensors, shape = _flatten_tensors(input_tensors)
+ devices = [t.device for t in input_tensors]
+ per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
+ num_workers = len(per_worker_devices)
+ up_values = [None for w in range(0, num_workers)]
+ up_devices = up_values[:]
+ down_values = up_values[:]
+ # First stage: reduce within each worker using NCCL
+ for w in range(0, num_workers):
+ worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
+ # NOTE: these reductions will not run to completion unless
+ # every output value is used. Since we only need one, we
+ # need to put control dependencies on the rest.
+ with ops.control_dependencies(worker_values):
+ with ops.device(worker_values[0].device):
+ up_values[w] = array_ops.identity(worker_values[0])
+ up_devices[w] = per_worker_devices[w][0]
+ # Second stage: Apply upper_level_f to reduce across first device at
+ # each worker
+ level_2_output = upper_level_f(up_values)
+ # Third stage: propagate within each worker using NCCL Broadcast
+ for w in range(0, num_workers):
+ dst_devices = per_worker_devices[w][1:]
+ send_op, dst_tensors = nccl.broadcast(level_2_output[w], dst_devices)
+ # NOTE: need control dependency to ensure send_op executes
+ with ops.control_dependencies([send_op]):
+ with ops.device(per_worker_devices[w][0]):
+ dst_tensors.insert(0, array_ops.identity(level_2_output[w]))
+ down_values[w] = dst_tensors
+ output_tensors = [v for sublist in down_values for v in sublist]
+ if len(shape) > 1:
+ output_tensors = _reshape_tensors(output_tensors, shape)
+ return output_tensors
+
+
+def _reduce_non_singleton(input_tensors, red_f, un_op):
+ """If input_tenors has more than one element apply red_f, else apply un_op."""
+ if len(input_tensors) > 1:
+ return red_f(input_tensors)
+ else:
+ output_tensors = []
+ for t in input_tensors:
+ with ops.colocate_with(t):
+ output_tensors.append(un_op(t))
+ return output_tensors
+
+
+def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
+ """Construct hybrid of NCCL within workers, Ring across workers."""
+ def upper_builder(y):
+ return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
+ def upper_level_f(x):
+ return _reduce_non_singleton(x, upper_builder, un_op)
+ return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
+
+
+def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
+ """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
+ upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
+ return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
+
+
+def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
+ shuffle_red_op, un_op=None):
+ """Construct hybrid of NCCL within workers, Shuffle across workers."""
+ upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices,
+ shuffle_red_op, un_op)
+ return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)
+
+
+def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
+ """Construct a subgraph for Shuffle hybrid all-reduce.
+
+ Args:
+ input_tensors: list of T @{tf.Tensor} of same-shape and type values to
+ be reduced.
+ gather_devices: list of device names on which to host gather shards.
+ red_op: binary elementwise reduction operator.
+ upper_level_f: function for reducing one value per worker, across
+ workers.
+
+ Returns:
+ list of T @{tf.Tensor} of reduced values.
+
+ Raises:
+ ValueError: inputs not well-formed.
+ """
+ input_tensors, shape = _flatten_tensors(input_tensors)
+ # First stage, reduce across each worker using gather_devices.
+ devices = [t.device for t in input_tensors]
+ per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
+ num_workers = len(per_worker_devices)
+ up_values = []
+ if len(gather_devices) != num_workers:
+ raise ValueError("For shuffle hybrid, gather_devices must contain one "
+ "device per worker. ")
+ for w in range(0, num_workers):
+ reduced_shards = _build_shuffle_gather(
+ per_worker_values[w], [gather_devices[w]], red_op)
+ up_values.append(reduced_shards[0])
+ # Second stage, apply upper_level_f.
+ level_2_output = upper_level_f(up_values)
+ # Third stage, apply shuffle scatter at each worker.
+ output_tensors = []
+ for w in range(0, num_workers):
+ output_tensors += _build_shuffle_scatter(
+ [level_2_output[w]], per_worker_devices[w])
+ if len(shape) > 1:
+ output_tensors = _reshape_tensors(output_tensors, shape)
+ return output_tensors
+
+
+def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
+ red_n_op, red_op, un_op):
+ """Construct hybrid of Shuffle within workers, Ring across workers."""
+ def upper_builder(tensors):
+ return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
+ red_op, un_op)
+ def upper_level_f(tensors):
+ return _reduce_non_singleton(tensors, upper_builder, un_op)
+ return _build_shuffle_hybrid(
+ input_tensors, gather_devices, red_n_op, upper_level_f)
+
+
+def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
+ second_gather_devices, red_op, un_op=None):
+ """Construct hybrid of Shuffle within workers, Shuffle across workers."""
+ def upper_builder(tensors):
+ return build_shuffle_all_reduce(tensors, second_gather_devices,
+ red_op, un_op)
+ def upper_level_f(tensors):
+ return _reduce_non_singleton(tensors, upper_builder, un_op)
+ return _build_shuffle_hybrid(
+ input_tensors, first_gather_devices, red_op, upper_level_f)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
new file mode 100644
index 0000000000..0802b27369
--- /dev/null
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -0,0 +1,229 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.contrib.all_reduce.python..all_reduce."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.all_reduce.python import all_reduce as ar
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class AllReduceTest(test_util.TensorFlowTestCase):
+
+ def testRingPermutations(self):
+ # 0 devices
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])
+ self.assertEqual(pred_by_c_d, [])
+ self.assertEqual(rank_by_c_d, [])
+ # 1 worker, 1 subchunk cases
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
+ self.assertEqual(pred_by_c_d, [[0]])
+ self.assertEqual(rank_by_c_d, [[0]])
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2])
+ self.assertEqual(pred_by_c_d, [[2, 0, 1]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2]])
+ # multiple workers, 1 subchunk cases
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2])
+ self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]])
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2])
+ self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]])
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0])
+ self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]])
+ self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]])
+ # 1 worker, multiple subchunk cases
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
+ self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]])
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3])
+ self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2],
+ [3, 0, 1, 2], [3, 0, 1, 2]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2],
+ [2, 3, 0, 1], [1, 2, 3, 0]])
+ # multiple worker, multiple subchunk cases
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3])
+ self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6],
+ [3, 0, 5, 2, 7, 4, 1, 6]])
+ self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7],
+ [2, 3, 0, 1, 6, 7, 4, 5]])
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1])
+ self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4],
+ [1, 2, 7, 0, 5, 6, 3, 4]])
+ self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5],
+ [2, 1, 0, 3, 6, 5, 4, 7]])
+
+ def _buildInput(self, num_workers, num_gpus):
+ t8 = constant_op.constant(
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ types_pb2.DT_FLOAT)
+ input_tensors = []
+ device_names = []
+ for w in range(0, num_workers):
+ for d in range(0, num_gpus):
+ dn = "/replica:0/task:%d/device:GPU:%d" % (w, d % num_gpus)
+ device_names.append(dn)
+ with ops.device(dn):
+ input_tensors.append(array_ops.identity(t8))
+ return input_tensors, device_names
+
+ def testBuildRingGatherPassStructure(self):
+ # 1 worker, 1 device
+ input_tensors, device_names = self._buildInput(1, 1)
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
+ output_tensors = ar._build_ring_gather(input_tensors, device_names, 1,
+ pred_by_c_d, rank_by_c_d,
+ math_ops.add)
+ self.assertEqual(output_tensors, input_tensors)
+ # 1 worker, 4 devices, 2 subchunks
+ input_tensors, device_names = self._buildInput(1, 4)
+ pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
+ output_tensors, pad_len = ar._build_ring_gather(
+ input_tensors, device_names, 2, pred_by_c_d, rank_by_c_d, math_ops.add)
+ self.assertEqual(0, pad_len)
+ # same number outputs as inputs
+ self.assertEqual(len(output_tensors), len(input_tensors))
+ num_chunks = 2 * len(input_tensors)
+ tlen = input_tensors[0].shape[0].value
+ for otl in output_tensors:
+ self.assertEqual(len(otl), num_chunks)
+ for ot in otl:
+ self.assertEqual(ot.shape, [tlen/num_chunks])
+
+ def _buildInitialVars(self, shape, dev_list):
+ values = []
+ num_devices = len(dev_list)
+ dim = np.prod(shape)
+ for d in range(0, num_devices):
+ with ops.device(dev_list[d]):
+ npt = np.zeros(shape).astype(np.float32)
+ alias = np.frombuffer(npt.data, dtype=np.float32)
+ for i in range(0, dim):
+ alias[i] = i + 0.01 * d
+ var = state_ops.variable_op(shape, types_pb2.DT_FLOAT)
+ state_ops.init_variable(var, npt).op.run()
+ values.append(var)
+ return values
+
+ # pylint: disable=g-long-lambda
+
+ def _buildRing(self, num_workers, num_gpus, subdiv):
+ gpu_perm = range(0, num_gpus)
+ return lambda x, un_op: ar.build_ring_all_reduce(
+ x, num_workers, subdiv, gpu_perm, math_ops.add, un_op)
+
+ def _testAllReduce(self, num_workers, num_gpus, shape, build_f):
+ # Use local CPU as device for all inputs.
+ num_devices = num_workers * num_gpus
+ dev_list = ["/replica:0/task:0/device:CPU:0"
+ for _ in range(num_devices)]
+ with self.test_session():
+ input_tensors = self._buildInitialVars(shape, dev_list)
+ un_op = lambda x: math_ops.div(
+ x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
+ simple_sum = math_ops.add_n(input_tensors)
+ simple_sum.op.run()
+ output_tensors = build_f(input_tensors, un_op)
+ sum_reduced = math_ops.add_n(output_tensors)
+ sum_reduced.op.run()
+ self.assertAllClose(sum_reduced.eval(), simple_sum.eval())
+
+ def _testRingAllReduce(self, num_workers, num_gpus, shape, subdiv):
+ start_time = time.time()
+ build_f = self._buildRing(num_workers, num_gpus, subdiv)
+ self._testAllReduce(num_workers, num_gpus, shape, build_f)
+ elapsed = time.time() - start_time
+ tf_logging.info("RingAllReduce num_workers=%d num_gpus=%d shape=%s "
+ "subdiv=%d elapsed=%f" %
+ (num_workers, num_gpus, shape, subdiv, elapsed))
+
+ def testRingAllReduce(self):
+ self._testRingAllReduce(1, 2, [8], 1)
+ self._testRingAllReduce(1, 2, [4, 4], 1)
+ self._testRingAllReduce(6, 1, [8], 1)
+ self._testRingAllReduce(1, 8, [32], 1)
+ self._testRingAllReduce(1, 8, [120], 1)
+ self._testRingAllReduce(2, 8, [7, 13], 1)
+ self._testRingAllReduce(2, 8, [8, 8], 2)
+ self._testRingAllReduce(2, 8, [8, 8], 4)
+ # TODO(tucker): The following test is surprisingly slow.
+ # Diagnose and fix before re-enabling.
+ # self._testRingAllReduce(4, 8, [8, 8, 2], 4)
+
+ def _buildShuffle(self, num_workers, num_gpus, num_shards):
+ # Use local CPU for all shuffle shards
+ gather_devices = ["/replica:0/task:0/device:CPU:0"
+ for _ in range(num_shards)]
+ return lambda x, un_op: ar.build_shuffle_all_reduce(
+ x, gather_devices, math_ops.add_n, un_op)
+
+ def _testShuffleAllReduce(self, num_workers, num_gpus, shape, num_shards):
+ start_time = time.time()
+ build_f = self._buildShuffle(num_workers, num_gpus, num_shards)
+ self._testAllReduce(num_workers, num_gpus, shape, build_f)
+ elapsed = time.time() - start_time
+ tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
+ "elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
+
+ def testShuffleAllReduce(self):
+ self._testShuffleAllReduce(1, 2, [8], 1)
+ self._testShuffleAllReduce(1, 2, [4, 4], 1)
+ self._testShuffleAllReduce(1, 8, [32], 1)
+ self._testShuffleAllReduce(1, 8, [120], 1)
+ self._testShuffleAllReduce(2, 8, [7, 13], 3)
+ self._testShuffleAllReduce(2, 8, [8, 8], 2)
+ self._testShuffleAllReduce(2, 8, [8, 8], 4)
+ self._testShuffleAllReduce(4, 8, [8, 8, 2], 4)
+
+ def _buildRecursiveHD(self, num_workers, num_gpus):
+ return lambda x, un_op: ar.build_recursive_hd_all_reduce(
+ x, math_ops.add, un_op)
+
+ # pylint: enable=g-long-lambda
+
+ def _testRecursiveHDAllReduce(self, num_workers, num_gpus, shape):
+ start_time = time.time()
+ build_f = self._buildRecursiveHD(num_workers, num_gpus)
+ self._testAllReduce(num_workers, num_gpus, shape, build_f)
+ elapsed = time.time() - start_time
+ tf_logging.info("RecursiveHDAllReduce num_workers=%d num_gpus=%d "
+ "shape=%s elapsed=%f" %
+ (num_workers, num_gpus, shape, elapsed))
+
+ def testRecursiveHDAllReduce(self):
+ self._testRecursiveHDAllReduce(1, 2, [8])
+ self._testRecursiveHDAllReduce(1, 2, [4, 4])
+ self._testRecursiveHDAllReduce(1, 8, [32])
+ self._testRecursiveHDAllReduce(1, 8, [120])
+ self._testRecursiveHDAllReduce(2, 8, [8, 8])
+ self._testRecursiveHDAllReduce(4, 8, [8, 8, 2])
+
+
+if __name__ == "__main__":
+ test.main()