about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/training/distribute.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- tensorflow/python/training/distribute.py 51
1 files changed, 31 insertions, 20 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..a92a1bdee7 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -807,15 +806,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +831,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)