diff options
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- | tensorflow/python/training/distribute.py | 51 |
1 files changed, 31 insertions, 20 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 419a9ec12b..a92a1bdee7 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging @@ -807,15 +806,22 @@ class DistributionStrategy(object): var: Variable, possibly mirrored to multiple devices, to operate on. fn: Function to call. Should take the variable as the first argument. *args: Additional positional arguments to pass to `fn()`. - **kwargs: Keyword arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is + specified, the return value will be unwrapped. Returns: - Merged return value of `fn` across all towers. + By default, the merged return value of `fn` across all towers. The merged + result has dependencies to make sure that if it is evaluated at all, the + side effects (updates) will happen on every tower. If instead + "grouped=False" is specified, this function will return a nest of lists + where each list has an element per tower, and the caller is responsible + for ensuring all elements are executed. """ _require_cross_tower_context(self) - return self._update(var, fn, *args, **kwargs) + options = {"grouped": kwargs.pop("grouped", True)} + return self._update(var, options, fn, *args, **kwargs) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") def update_non_slot(self, colocate_with, fn, *args, **kwargs): @@ -825,15 +831,18 @@ class DistributionStrategy(object): colocate_with: The return value of `non_slot_devices()`. fn: Function to execute. *args: Positional arguments to pass to `fn()`. - **kwargs: Keyword arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is + specified, the return value will be unwrapped and the caller is + responsible for ensuring all elements are executed. Returns: Return value of `fn`, possibly merged across devices. """ _require_cross_tower_context(self) - return self._update_non_slot(colocate_with, fn, *args, **kwargs) + options = {"grouped": kwargs.pop("grouped", True)} + return self._update_non_slot(colocate_with, options, fn, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") def unwrap(self, value): @@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy): del aggregation, destinations return value - def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Figure out what we should be passing to UpdateContext() - # once that value is used for something. - with ops.colocate_with(var), UpdateContext(var): - return fn(var, *args, **kwargs) + def _update(self, var, options, fn, *args, **kwargs): + # The implementations of _update() and _update_non_slot() are identical + # except _update() passes `var` as the first argument to `fn()`. + return self._update_non_slot(var, options, fn, var, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. # TODO(josh11b): Figure out what we should be passing to UpdateContext() # once that value is used for something. with ops.colocate_with(colocate_with), UpdateContext(colocate_with): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def read_var(self, tower_local_var): return array_ops.identity(tower_local_var) @@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy): def increment_var(v, amount=1): """`v += amount`, distributed-aware version.""" def update(vu): - if isinstance(vu, resource_variable_ops.ResourceVariable): - return vu.assign_add(amount, read_value=False) - else: - return state_ops.assign_add(vu, amount) + return vu.assign_add(amount, read_value=False) def merge_fn(dist, vm): - return dist.group(dist.update(vm, update)) + return dist.update(vm, update) tower_context = distribution_strategy_context.get_tower_context() return tower_context.merge_call(merge_fn, v) |