aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2018-08-15 20:51:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 20:57:39 -0700
commitff6525acf1bdf8afe65f4fd93047669f08f5e061 (patch)
tree0fb628f68d371d55bcb4863f019cba11dc5d6307 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentf7023480452e4b4781d343acf76ae720540b1423 (diff)
Step_fn should be able to receive un-wrapped inputs
PiperOrigin-RevId: 208929959
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e3376a0636..72d1c6b7dd 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -372,7 +372,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_result = fn(ctx, iterator.get_next())
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self.unwrap(output)