about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/all_reduce
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-10-06 11:34:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-06 11:37:55 -0700
commit84b579e1d14760fc2a313c8e1d7ca100f74945a1 (patch)
tree0d1ff16bbf650a4924c81ad9ed53a18a1d584f5c /tensorflow/contrib/all_reduce
parent549e651106e1e582dad0e8a6ea57b8f59ce95067 (diff)
[XLA:CPU] Make EmitTargetAddressForOp return void (well, technically Status).
This is a general cleanup -- less repeated code -- but it's also part of an effort to use IrArray more and llvm::Value less. In particular, many callsites would take the llvm::Value returned by EmitTargetAddressForOp and create an IrArray out of it, but then never attach AA info to that array. Having this function return void forces you to call GetIrArrayForOp(), which attaches the AA metadata appropriately. This change also gets rid of an unused arg to EmitTargetAddressForOp. PiperOrigin-RevId: 171320201
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce.py4
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 8e7f1791b8..22d7633ce2 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op):
if len(input_tensors) > 1:
return red_f(input_tensors)
else:
+ if not un_op:
+ return input_tensors
output_tensors = []
for t in input_tensors:
with ops.colocate_with(t):
@@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
- red_n_op, red_op, un_op):
+ red_n_op, red_op, un_op=None):
"""Construct hybrid of Shuffle within workers, Ring across workers."""
def upper_builder(tensors):
return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],