aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/all_reduce
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 8e7f1791b8..22d7633ce2 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op):
if len(input_tensors) > 1:
return red_f(input_tensors)
else:
+ if not un_op:
+ return input_tensors
output_tensors = []
for t in input_tensors:
with ops.colocate_with(t):
@@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
- red_n_op, red_op, un_op):
+ red_n_op, red_op, un_op=None):
"""Construct hybrid of Shuffle within workers, Ring across workers."""
def upper_builder(tensors):
return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],