author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-12 10:34:20 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-02-12 10:41:59 -0800
commit    a55c133748b695778163035f098cf3ac34302872 (patch)
tree      585de69bf74dce55ab8bc65480931dd6dd824456 /tensorflow/contrib/all_reduce
parent    7ecc83de23df629fcebb541262925e34dd17cc84 (diff)
Add support for scalars in `tf.contrib.all_reduce`.
PiperOrigin-RevId: 185398372
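
Before this change, the flatten step ran only for tensors of rank 2 and above: a rank-0 (scalar) input failed the old `len(shape) > 1` test, was never reshaped to 1-D, and then broke the chunking logic downstream, which assumes rank-1 inputs. Testing `len(shape) != 1` instead flattens everything that is not already 1-D, scalars included. A minimal sketch of the predicate change (illustrative only, not part of the patch):

# Rank 0 is the case the old predicate missed: `len(shape) > 1` is False
# for a scalar, so the reshape to 1-D was skipped; `!= 1` catches it.
import tensorflow as tf

for t in (tf.constant(3.0),         # shape (),     rank 0
          tf.constant([1.0, 2.0]),  # shape (2,),   rank 1
          tf.ones([2, 2])):         # shape (2, 2), rank 2
  shape = t.shape
  print(len(shape), len(shape) > 1, len(shape) != 1)
# 0 False True   -> scalar: the old test skipped flattening (the bug)
# 1 False False  -> already 1-D: neither predicate flattens, correctly
# 2 True  True   -> both flatten; behavior unchanged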
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r--  tensorflow/contrib/all_reduce/python/all_reduce.py      | 12 ++++++------
-rw-r--r--  tensorflow/contrib/all_reduce/python/all_reduce_test.py |  4 +++-
2 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 28f60b3499..16617b7266 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -48,7 +48,7 @@ def _flatten_tensors(tensors):
   if shape.ndims is None:
     raise ValueError("At least one of the tensors in 'tensors' must have "
                      "statically known rank.")
-  if len(shape) > 1:
+  if len(shape) != 1:
     reshaped = []
     for t in tensors:
       with ops.colocate_with(t):
@@ -289,7 +289,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                                        chunks_by_dev)
   if pad_len > 0:
     output_tensors = _strip_padding(output_tensors, pad_len)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -466,7 +466,7 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
   if un_op:
     reduced_shards = [un_op(t) for t in reduced_shards]
   output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -578,7 +578,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
   reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                          red_op, un_op)
   output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -752,7 +752,7 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
         dst_tensors.append(array_ops.identity(broadcast_src))
     down_values[w] = dst_tensors
   output_tensors = [v for sublist in down_values for v in sublist]
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
@@ -831,7 +831,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
   for w in range(0, num_workers):
     output_tensors += _build_shuffle_scatter(
         [level_2_output[w]], per_worker_devices[w])
-  if len(shape) > 1:
+  if len(shape) != 1:
     output_tensors = _reshape_tensors(output_tensors, shape)
   return output_tensors
 
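
All six call sites above share the same round trip: `_flatten_tensors` reshapes every input to 1-D on the way in, and `_reshape_tensors` restores the original shape on the way out whenever the input was not already 1-D. For a scalar the round trip looks roughly like this (a sketch of the shapes involved, not the library code itself):

import tensorflow as tf

t = tf.constant(3.0)             # rank 0: rejected before this patch
flat = tf.reshape(t, [-1])       # shape (1,): now chunkable by ring/shuffle
restored = tf.reshape(flat, [])  # back to shape () after the reduction
print(flat.shape, restored.shape)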
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index 0802b27369..47bab0a367 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -119,7 +119,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
   def _buildInitialVars(self, shape, dev_list):
     values = []
     num_devices = len(dev_list)
-    dim = np.prod(shape)
+    dim = np.prod(shape) if shape else 1
     for d in range(0, num_devices):
       with ops.device(dev_list[d]):
         npt = np.zeros(shape).astype(np.float32)
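
The guard added here works around `np.prod` of an empty shape: NumPy returns the float 1.0 for an empty product, and `dim` is used below as a `range()` bound, which rejects floats under Python 3. The motivation is my reading of the change; the NumPy behavior itself is standard:

import numpy as np

print(np.prod([8]))  # 8, an integer
print(np.prod([]))   # 1.0, a float64 -- range(0, np.prod([])) raises TypeError
dim = np.prod([]) if [] else 1
print(dim)           # 1, a plain int, safe as a loop bound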
@@ -164,6 +164,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
             (num_workers, num_gpus, shape, subdiv, elapsed))
 
   def testRingAllReduce(self):
+    self._testRingAllReduce(1, 2, [], 1)
     self._testRingAllReduce(1, 2, [8], 1)
     self._testRingAllReduce(1, 2, [4, 4], 1)
     self._testRingAllReduce(6, 1, [8], 1)
@@ -192,6 +193,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
             "elapsed=%f" % (num_workers, num_gpus, shape, elapsed))
 
   def testShuffleAllReduce(self):
+    self._testShuffleAllReduce(1, 2, [], 1)
     self._testShuffleAllReduce(1, 2, [8], 1)
     self._testShuffleAllReduce(1, 2, [4, 4], 1)
     self._testShuffleAllReduce(1, 8, [32], 1)
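
With the patch applied, one scalar per device becomes a legal input to the builders in this module. A hedged usage sketch for TF 1.x graph mode: the argument order past what the hunk header above shows (device permutation, reduction op) is assumed from the file, and the duplicated CPU stand-in device is my simplification; the real tests run across GPUs:

import tensorflow as tf
from tensorflow.contrib.all_reduce.python import all_reduce

graph = tf.Graph()
with graph.as_default():
  scalars = []
  for i in range(2):  # two replicas, one scalar each
    with tf.device("/cpu:0"):
      scalars.append(tf.constant(float(i)))  # shape (), rank 0
  # 1 worker, 1 subchunk, identity device permutation, binary sum reduction.
  outputs = all_reduce.build_ring_all_reduce(scalars, 1, 1, [0, 1], tf.add)
  with tf.Session(graph=graph) as sess:
    print(sess.run(outputs))  # expect [1.0, 1.0]: each replica holds 0.0 + 1.0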