author    Sourabh Bajaj <sourabhbajaj@google.com>          2018-08-15 20:51:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-15 20:57:39 -0700
commit ff6525acf1bdf8afe65f4fd93047669f08f5e061 (patch)
tree   0fb628f68d371d55bcb4863f019cba11dc5d6307
parent f7023480452e4b4781d343acf76ae720540b1423 (diff)
Step_fn should be able to receive un-wrapped inputs
PiperOrigin-RevId: 208929959
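In brief: the per-strategy run loops used to call `fn(ctx, iterator.get_next())`, so every `step_fn` had to unpack its own inputs; after this change the iterator output is normalized to a tuple and splatted into the call as `fn(ctx, *fn_inputs)`. A minimal sketch of what that means for a step function (the names `features`/`labels` and the stand-in values below are illustrative, not the TensorFlow code):

def step_fn_before(ctx, inputs):
  # Old convention: the iterator output arrives as a single argument and
  # the step function unpacks it itself.
  del ctx  # Unused
  features, labels = inputs
  return features, labels


def step_fn_after(ctx, features, labels):
  # New convention: the run loop splats the iterator output, so tuple
  # elements arrive as separate positional arguments.
  del ctx  # Unused
  return features, labels


# Hypothetical stand-in for iterator.get_next() on a zipped dataset.
fn_inputs = ([[1.], [2.]], [[6.], [21.]])
print(step_fn_before(None, fn_inputs))
print(step_fn_after(None, *fn_inputs))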
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py  | 19
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py   |  5
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py |  5
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py             |  4
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py        |  5
-rw-r--r--  tensorflow/python/estimator/estimator.py                    |  3
-rw-r--r--  tensorflow/python/training/distribute.py                    |  9
7 files changed, 31 insertions, 19 deletions
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index aa7a61bb3b..516ede7ade 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -56,11 +56,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
return distribution.group(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=layer.built))
+ model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
@@ -153,11 +153,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
use_callable_loss=True,
create_optimizer_inside_model_fn=True)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
return distribution.group(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=layer.built))
+ model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
@@ -231,11 +231,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
if isinstance(distribution, mirrored_strategy.MirroredStrategy):
self.assertFalse(distribution._prefetch_on_device)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
fetches = distribution.unwrap(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=batchnorm.built))
+ model_fn, *inputs, run_concurrently=batchnorm.built))
if update_ops_in_cross_tower_mode:
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
@@ -328,9 +328,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
return dataset_ops.Dataset.zip((features, labels)).repeat()
- def step_fn(ctx, inputs):
+ def step_fn(ctx, x, y):
del ctx # Unused
- x, y = inputs
return distribution.group(
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
@@ -417,9 +416,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output_context.set_non_tensor_output(key1, value1)
return (train_op, loss)
- def step_fn(output_context, inputs):
+ def step_fn(output_context, *inputs):
(train_op, loss) = distribution.call_for_each_tower(
- model_fn, output_context, inputs, run_concurrently=False)
+ model_fn, output_context, *inputs, run_concurrently=False)
output_context.set_last_step_output(
name="cross_tower_loss_agg",
output=loss,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e3376a0636..72d1c6b7dd 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -372,7 +372,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_result = fn(ctx, iterator.get_next())
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self.unwrap(output)
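The same three-line guard recurs in each strategy's loop body below (one_device_strategy and tpu_strategy): whatever `iterator.get_next()` returns is normalized to a tuple before being splatted into `fn`. A self-contained sketch of the idiom, with a stub iterator standing in for the distributed one (all names here are illustrative, not the TF classes):

class FakeIterator(object):
  """Stub standing in for a distributed dataset iterator."""

  def __init__(self, value):
    self._value = value

  def get_next(self):
    return self._value


def run_one_step(fn, ctx, iterator):
  # Same normalization as in body()/run_fn() in this commit: wrap a single
  # value so that fn(ctx, *fn_inputs) works for both shapes of output.
  fn_inputs = iterator.get_next()
  if not isinstance(fn_inputs, tuple):
    fn_inputs = (fn_inputs,)
  return fn(ctx, *fn_inputs)


def step_fn(ctx, x, y=None):
  del ctx  # Unused
  return (x, y)


print(run_one_step(step_fn, None, FakeIterator((1., 6.))))  # tuple input
print(run_one_step(step_fn, None, FakeIterator(1.)))        # single input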
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 016978cdb3..86833ad851 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -80,7 +80,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_result = fn(ctx, iterator.get_next())
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index d3611570b4..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -90,14 +90,14 @@ class StandardSingleLossStep(StandardInputStep):
def __call__(self):
with self._distribution.scope():
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
"""Function to run one iteration with one input."""
gradients_fn = backprop.implicit_grad(self._loss_fn)
gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
grads_and_vars = self.distribution.call_for_each_tower(
gradients_fn,
- ctx, inputs,
+ ctx, *inputs,
run_concurrently=self._is_run_concurrently)
# If threads use layers, then we need to run the first step
# sequentially, so that layers.build() is not executed in parallel.
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index b510fdb888..3f8a0922de 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -144,7 +144,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
ctx = values.MultiStepContext()
def run_fn(*args, **kwargs):
del args, kwargs
- fn_result = fn(ctx, dequeue_fn())
+ fn_inputs = dequeue_fn()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
if flat_last_step_outputs:
with ops.control_dependencies([fn_result]):
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eab608813b..52002fd79b 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1245,9 +1245,8 @@ class Estimator(object):
self._train_distribution.read_var(global_step_tensor))
# Create a step_fn from the train_op of grouped_estimator_spec
- def step_fn(ctx, inputs):
+ def step_fn(ctx, features, labels):
"""A single step that is passed to run_on_dataset."""
- features, labels = inputs
estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
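For the Estimator path above, the features/labels pair comes from a zipped dataset, so `iterator.get_next()` already yields a two-element tuple and the new `step_fn(ctx, features, labels)` signature lines up with it. A rough sketch of that data shape, assuming a TF 1.x runtime (this commit predates 2.x); the values mirror the test above but are illustrative:

import tensorflow as tf

features_ds = tf.data.Dataset.from_tensors([[1.], [2.]])
labels_ds = tf.data.Dataset.from_tensors([[6.], [21.]])
dataset = tf.data.Dataset.zip((features_ds, labels_ds)).repeat()

iterator = dataset.make_one_shot_iterator()
fn_inputs = iterator.get_next()  # a (features, labels) tuple of tensors
assert isinstance(fn_inputs, tuple)


def step_fn(ctx, features, labels):
  # Matches the new signature: no manual `features, labels = inputs`.
  del ctx  # Unused
  return labels - features


with tf.Session() as sess:
  print(sess.run(step_fn(None, *fn_inputs)))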
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 28c60ad809..20e031569b 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -624,13 +624,18 @@ class DistributionStrategy(object):
Args:
fn: function to run using this distribution strategy. The function must
- have the following signature: def fn(context, inputs).
+ have the following signature: def fn(context, *inputs).
`context` is an instance of `MultiStepContext` that will be passed when
`fn` is run. `context` can be used to specify the outputs to be returned
from `fn` by calling `context.set_last_step_output`. It can also be used
to capture non tensor outputs by `context.set_non_tensor_output`.
See `MultiStepContext` documentation for more information.
- `inputs` will have same type/structure as `iterator.get_next()`.
+ `inputs` will have same type/structure as `iterator.get_next()`. If
+ `iterator.get_next()` returns a tuple, say `return x, y`, then those
+ values will be unpacked and passed to the `step_fn`; the step_fn
+ signature would then look like `def step_fn(context, x, y)`. If the
+ iterator returns a single value, say `return x`, then the value is
+ passed as is; the step_fn signature would look like `def step_fn(context, x)`.
Typically, `fn` will use `call_for_each_tower` method of the strategy
to distribute the computation over multiple towers.
iterator: Iterator of a dataset that represents the input for `fn`. The