author     2018-08-15 20:51:09 -0700
committer  2018-08-15 20:57:39 -0700
commit     ff6525acf1bdf8afe65f4fd93047669f08f5e061
tree       0fb628f68d371d55bcb4863f019cba11dc5d6307
parent     f7023480452e4b4781d343acf76ae720540b1423
Step_fn should be able to receive un-wrapped inputs
PiperOrigin-RevId: 208929959
7 files changed, 31 insertions, 19 deletions
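The change in a nutshell: the strategies' loop bodies now unpack tuple elements from the iterator into positional arguments of `step_fn`, wrapping a non-tuple element in a 1-tuple first. A minimal, pure-Python sketch of that calling convention; `run_step` here is a hypothetical stand-in for the while-loop bodies changed below, not a TensorFlow API:

def run_step(step_fn, ctx, element):
  """Stand-in for the strategies' while-loop bodies after this change."""
  if not isinstance(element, tuple):
    element = (element,)         # a single value still arrives as one argument
  return step_fn(ctx, *element)  # tuple elements are unpacked positionally

def step_fn_pair(ctx, x, y):     # dataset yields (x, y)
  del ctx  # unused
  return x + y

def step_fn_single(ctx, x):      # dataset yields a single value
  del ctx  # unused
  return x * 2

assert run_step(step_fn_pair, None, (3, 4)) == 7
assert run_step(step_fn_single, None, 5) == 10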
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index aa7a61bb3b..516ede7ade 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -56,11 +56,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     model_fn, dataset_fn, layer = minimize_loss_example(
         optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
 
-    def step_fn(ctx, inputs):
+    def step_fn(ctx, *inputs):
       del ctx  # Unused
       return distribution.group(
           distribution.call_for_each_tower(
-              model_fn, inputs, run_concurrently=layer.built))
+              model_fn, *inputs, run_concurrently=layer.built))
 
     iterator = distribution.distribute_dataset(
         dataset_fn).make_one_shot_iterator()
@@ -153,11 +153,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         use_callable_loss=True,
         create_optimizer_inside_model_fn=True)
 
-    def step_fn(ctx, inputs):
+    def step_fn(ctx, *inputs):
       del ctx  # Unused
       return distribution.group(
           distribution.call_for_each_tower(
-              model_fn, inputs, run_concurrently=layer.built))
+              model_fn, *inputs, run_concurrently=layer.built))
 
     iterator = distribution.distribute_dataset(
         dataset_fn).make_one_shot_iterator()
@@ -231,11 +231,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     if isinstance(distribution, mirrored_strategy.MirroredStrategy):
       self.assertFalse(distribution._prefetch_on_device)
 
-    def step_fn(ctx, inputs):
+    def step_fn(ctx, *inputs):
       del ctx  # Unused
       fetches = distribution.unwrap(
           distribution.call_for_each_tower(
-              model_fn, inputs, run_concurrently=batchnorm.built))
+              model_fn, *inputs, run_concurrently=batchnorm.built))
       if update_ops_in_cross_tower_mode:
         fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
       return control_flow_ops.group(fetches)
@@ -328,9 +328,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
       return dataset_ops.Dataset.zip((features, labels)).repeat()
 
-    def step_fn(ctx, inputs):
+    def step_fn(ctx, x, y):
       del ctx  # Unused
-      x, y = inputs
       return distribution.group(
           distribution.call_for_each_tower(
               model_fn, x, y, run_concurrently=False))
@@ -417,9 +416,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         output_context.set_non_tensor_output(key1, value1)
       return (train_op, loss)
 
-    def step_fn(output_context, inputs):
+    def step_fn(output_context, *inputs):
       (train_op, loss) = distribution.call_for_each_tower(
-          model_fn, output_context, inputs, run_concurrently=False)
+          model_fn, output_context, *inputs, run_concurrently=False)
       output_context.set_last_step_output(
           name="cross_tower_loss_agg",
           output=loss,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e3376a0636..72d1c6b7dd 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -372,7 +372,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     def body(i, *args):
       """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_result = fn(ctx, iterator.get_next())
+      fn_inputs = iterator.get_next()
+      if not isinstance(fn_inputs, tuple):
+        fn_inputs = (fn_inputs,)
+      fn_result = fn(ctx, *fn_inputs)
       for (name, output) in ctx.last_step_outputs.items():
         # Convert all outputs to tensors, potentially from `DistributedValues`.
         ctx.last_step_outputs[name] = self.unwrap(output)
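For context on the new `isinstance(fn_inputs, tuple)` guard above: with the TF 1.x `tf.data` API, `get_next()` on a zipped dataset yields a tuple of tensors, while a single-component dataset yields a bare tensor. A small illustrative snippet, not part of the commit:

import tensorflow as tf

features = tf.data.Dataset.from_tensors([[1.]]).repeat()
labels = tf.data.Dataset.from_tensors([[2.]]).repeat()

# A zipped dataset yields tuples: get_next() is a tuple of tensors, so
# the loop body can unpack it into step_fn(ctx, x, y).
pair = tf.data.Dataset.zip(
    (features, labels)).make_one_shot_iterator().get_next()
print(isinstance(pair, tuple))    # True

# A single-component dataset yields one tensor; the loop body wraps it
# as (fn_inputs,) so step_fn(ctx, x) still receives one argument.
single = features.make_one_shot_iterator().get_next()
print(isinstance(single, tuple))  # False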
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 016978cdb3..86833ad851 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -80,7 +80,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     def body(i, *args):
       """A wrapper around `fn` to create the while loop body."""
       del args
-      fn_result = fn(ctx, iterator.get_next())
+      fn_inputs = iterator.get_next()
+      if not isinstance(fn_inputs, tuple):
+        fn_inputs = (fn_inputs,)
+      fn_result = fn(ctx, *fn_inputs)
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       with ops.control_dependencies([fn_result]):
         return [i + 1] + flat_last_step_outputs
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index d3611570b4..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -90,14 +90,14 @@ class StandardSingleLossStep(StandardInputStep):
   def __call__(self):
     with self._distribution.scope():
 
-      def step_fn(ctx, inputs):
+      def step_fn(ctx, *inputs):
         """Function to run one iteration with one input."""
         gradients_fn = backprop.implicit_grad(self._loss_fn)
         gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
 
         grads_and_vars = self.distribution.call_for_each_tower(
             gradients_fn,
-            ctx, inputs,
+            ctx, *inputs,
             run_concurrently=self._is_run_concurrently)
         # If threads use layers, then we need to run the first step
         # sequentially, so that layers.build() is not executed in parallel.
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index b510fdb888..3f8a0922de 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -144,7 +144,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     ctx = values.MultiStepContext()
     def run_fn(*args, **kwargs):
       del args, kwargs
-      fn_result = fn(ctx, dequeue_fn())
+      fn_inputs = dequeue_fn()
+      if not isinstance(fn_inputs, tuple):
+        fn_inputs = (fn_inputs,)
+      fn_result = fn(ctx, *fn_inputs)
       flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
       if flat_last_step_outputs:
         with ops.control_dependencies([fn_result]):
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eab608813b..52002fd79b 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1245,9 +1245,8 @@ class Estimator(object):
             self._train_distribution.read_var(global_step_tensor))
 
         # Create a step_fn from the train_op of grouped_estimator_spec
-        def step_fn(ctx, inputs):
+        def step_fn(ctx, features, labels):
           """A single step that is passed to run_on_dataset."""
-          features, labels = inputs
           estimator_spec = self._train_distribution.call_for_each_tower(
               self._call_model_fn,
               features,
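The step_fn.py hunk above forwards `*inputs` through `call_for_each_tower` into the wrapped loss function, so a user-supplied loss function can now name its dataset components directly. A pure-Python sketch of that forwarding chain; `implicit_grad` here is a simplified stand-in for `backprop.implicit_grad`, and the returned grads are placeholders:

def implicit_grad(loss_fn):
  # Simplified stand-in: wraps a loss function and returns a placeholder
  # grads-and-vars list, for illustration only.
  def gradients_fn(*args, **kwargs):
    loss = loss_fn(*args, **kwargs)
    return [("grad", loss)]
  return gradients_fn

def loss_fn(ctx, x, y):  # can now name the unpacked dataset components
  del ctx  # unused
  return (x - y) ** 2

def step_fn(ctx, *inputs):  # mirrors the step_fn.py hunk above
  gradients_fn = implicit_grad(loss_fn)
  return gradients_fn(ctx, *inputs)  # *inputs forwarded unchanged

assert step_fn(None, 3.0, 1.0) == [("grad", 4.0)]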
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 28c60ad809..20e031569b 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -624,13 +624,18 @@ class DistributionStrategy(object):
 
     Args:
       fn: function to run using this distribution strategy. The function must
-        have the following signature: def fn(context, inputs).
+        have the following signature: def fn(context, *inputs).
        `context` is an instance of `MultiStepContext` that will be passed when
        `fn` is run. `context` can be used to specify the outputs to be returned
        from `fn` by calling `context.set_last_step_output`. It can also be used
        to capture non tensor outputs by `context.set_non_tensor_output`.
        See `MultiStepContext` documentation for more information.
-        `inputs` will have same type/structure as `iterator.get_next()`.
+        `inputs` will have the same type/structure as `iterator.get_next()`. If
+        `iterator.get_next()` returns a tuple, say `return x, y`, then those
+        will be unpacked and passed to the `step_fn`; the step_fn signature
+        would then look like `def step_fn(context, x, y)`. If the iterator
+        returns a single value, say `return x`, then that value is passed as
+        is and the step_fn signature would look like `def step_fn(context, x)`.
        Typically, `fn` will use `call_for_each_tower` method of the strategy
        to distribute the computation over multiple towers.
      iterator: Iterator of a dataset that represents the input for `fn`. The
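Taken together, the docstring describes the contract a caller now writes against. A hedged usage sketch, assuming the docstring belongs to the strategy's `run_steps_on_dataset` runner (the exact method name is not shown in this truncated diff) and using the TF 1.x contrib API from the files above; `model_fn` is a hypothetical per-tower computation:

import tensorflow as tf

def dataset_fn():
  features = tf.data.Dataset.from_tensors([[1.]]).repeat()
  labels = tf.data.Dataset.from_tensors([[2.]]).repeat()
  return tf.data.Dataset.zip((features, labels))

def model_fn(features, labels):  # hypothetical per-tower computation
  return tf.reduce_sum(features + labels)

distribution = tf.contrib.distribute.OneDeviceStrategy("/cpu:0")

def step_fn(ctx, features, labels):  # tuple elements arrive unpacked
  del ctx  # unused
  return distribution.group(
      distribution.call_for_each_tower(
          model_fn, features, labels, run_concurrently=False))

iterator = distribution.distribute_dataset(
    dataset_fn).make_one_shot_iterator()
# Runner name assumed from context; not shown in this truncated diff.
step_op = distribution.run_steps_on_dataset(step_fn, iterator, iterations=2)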