diff options
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r-- | tensorflow/contrib/all_reduce/python/all_reduce.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 8e7f1791b8..22d7633ce2 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op): if len(input_tensors) > 1: return red_f(input_tensors) else: + if not un_op: + return input_tensors output_tensors = [] for t in input_tensors: with ops.colocate_with(t): @@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op): + red_n_op, red_op, un_op=None): """Construct hybrid of Shuffle within workers, Ring across workers.""" def upper_builder(tensors): return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], |