diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2018-08-15 20:51:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 20:57:39 -0700 |
commit | ff6525acf1bdf8afe65f4fd93047669f08f5e061 (patch) | |
tree | 0fb628f68d371d55bcb4863f019cba11dc5d6307 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | f7023480452e4b4781d343acf76ae720540b1423 (diff) |
Step_fn should be able to receive un-wrapped inputs
PiperOrigin-RevId: 208929959
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index e3376a0636..72d1c6b7dd 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -372,7 +372,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args - fn_result = fn(ctx, iterator.get_next()) + fn_inputs = iterator.get_next() + if not isinstance(fn_inputs, tuple): + fn_inputs = (fn_inputs,) + fn_result = fn(ctx, *fn_inputs) for (name, output) in ctx.last_step_outputs.items(): # Convert all outputs to tensors, potentially from `DistributedValues`. ctx.last_step_outputs[name] = self.unwrap(output) |