author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-01 16:31:13 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-01 16:35:32 -0700
commit     bb1f9e1a57c8bc18325b3c86298be96e6647a0a3
tree       66809f6708a49a5b8d408ef0b349fe837d905cda /tensorflow/python/training
parent     49bbfec04b729960999ef054e3acab719631b101
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors.

Before this change, update() would return a Mirrored value where each
component was an update to a single mirror. This caused a problem: for
reading purposes, other DistributionStrategy methods consider it okay to
read any single component, so something like
session.run(strategy.update(...)) would perform the update on only one
replica.

The fix is to have the output be a Mirrored value that is actually the
identity operation returning the output on that device, but with a
control dependency making sure that the update actually happens on all
the replicas. This fix was already present in
MirroredVariable._assign_func; this CL moves it into update() and
generalizes it to multiple return values.

To disable the new grouping behavior, you may now pass "grouped=False" to
update(). Some callers (like Optimizer) perform many updates and prefer
to group all of them together at once for performance reasons. In that
case we still want to make sure the caller executes the update on all
replicas, so we return an unwrapped value instead of a Mirrored value.
This has the happy side effect of removing a number of unwrap calls in
client code, since unwrapping was the only safe way to use the Mirrored
value we used to return.

PiperOrigin-RevId: 215301909
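As a point of reference, the following is a minimal TF 1.x graph-mode sketch of the identity-plus-control-dependency pattern described above. The variables are illustrative stand-ins for the per-device components of a mirrored variable; this is not code from the CL itself.

```python
import tensorflow as tf

v0 = tf.Variable(0.0)  # stand-ins for the per-device components
v1 = tf.Variable(0.0)  # of a mirrored variable

updates = [v0.assign_add(1.0), v1.assign_add(1.0)]

# Each per-device output is an identity of its own update, but carries a
# control dependency on *all* updates, so evaluating any one component
# forces the update to run on every device.
with tf.control_dependencies(updates):
  out0 = tf.identity(updates[0])
  out1 = tf.identity(updates[1])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(out0)             # both assign_adds execute
  print(sess.run([v0, v1]))  # [1.0, 1.0]
```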
Diffstat (limited to 'tensorflow/python/training')
 -rw-r--r--  tensorflow/python/training/distribute.py                     | 51
 -rw-r--r--  tensorflow/python/training/distribution_strategy_context.py  |  2
 -rw-r--r--  tensorflow/python/training/optimizer.py                      | 10
 3 files changed, 38 insertions, 25 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..a92a1bdee7 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -807,15 +806,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +831,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
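The increment_var hunk above can drop the dist.group(...) wrapper because update() now returns a value that already depends on every per-device assign. Callers that pass grouped=False instead (as optimizer.py does below) take on that grouping themselves; here is a minimal TF 1.x sketch of that caller-side pattern, with hypothetical per-device ops standing in for what a mirrored strategy would return.

```python
import tensorflow as tf

# Hypothetical stand-ins for the per-device update ops a strategy would
# hand back when called with grouped=False.
v0 = tf.Variable(0.0)
v1 = tf.Variable(0.0)
per_device_updates = [v0.assign_add(1.0), v1.assign_add(1.0)]

global_step = tf.Variable(0)

# The caller sequences the global-step increment after *all* per-device
# updates, mirroring the control-dependency pattern used by the new
# optimizer.py code.
with tf.control_dependencies(per_device_updates):
  apply_updates = tf.assign_add(global_step, 1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(apply_updates)
  print(sess.run([v0, v1, global_step]))  # [1.0, 1.0, 1]
```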
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 998b5c35ce..ce580a406f 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -89,6 +89,7 @@ def get_tower_context():
"""Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context (this function
will return the default TowerContext object);
2. switches to cross-tower context (in which case this will return
@@ -121,6 +122,7 @@ def get_cross_tower_context():
"""Returns the current DistributionStrategy if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context;
2. switches to cross-tower context when entering a
`with DistributionStrategy.scope():` block;
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 30b0ed20c8..47034919e1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -692,7 +692,7 @@ class Optimizer(
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.unwrap(distribution.update(var, update, grad))
+ for op in distribution.update(var, update, grad, grouped=False)
]
def finish(self, update_ops):
@@ -700,13 +700,13 @@ class Optimizer(
non_slot_devices = distribution.non_slot_devices(var_list)
finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops)
+ non_slot_devices, finish, self, update_ops, grouped=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
- apply_updates = distribution.group(distribution.update(
- global_step, state_ops.assign_add, 1, name=name))
+ with ops.control_dependencies(finish_updates):
+ apply_updates = distribution.update(
+ global_step, state_ops.assign_add, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):