aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 16:31:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 16:35:32 -0700
commitbb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (patch)
tree66809f6708a49a5b8d408ef0b349fe837d905cda
parent49bbfec04b729960999ef054e3acab719631b101 (diff)
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors. Before this change, update() would return a Mirrored value that where each component was an update to a single mirror. This caused a problem since for reading purposes other DistributionStrategy methods would consider it okay to read any single component, and so if you for example did something like session.run(strategy.update(...)) it would only perform the update on one replica. The fix is to have the output be a Mirrored value that is actually the identity operation returning the output on that device, but that has a control dependency making sure that the update actually happens on all the replicas. This fix was already present in MirroredVariable._assign_func, this CL moves the fix into update() and generalizes it to multiple return values. To disable this new grouping behavior, you may now pass "grouped=False" to update(). For example, some callers (like Optimizer) are performing a lot of updates and they prefer to group all of them together at once for performance reasons. In this case, we still want to make sure the caller executes the update on all replicas, so we return an unwrapped value instead of a Mirrored value. This has the happy side effect of removing a bunch of unwrap calls in client code, since unwrapping was the only safe way to use the Mirrored value we used to return. PiperOrigin-RevId: 215301909
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py12
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py17
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py22
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py6
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py36
-rw-r--r--tensorflow/contrib/distribute/python/values.py36
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py32
-rw-r--r--tensorflow/python/training/distribute.py51
-rw-r--r--tensorflow/python/training/distribution_strategy_context.py2
-rw-r--r--tensorflow/python/training/optimizer.py10
13 files changed, 144 insertions, 88 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 33ffbf6abe..6796a23d46 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 4d7516063c..6bd380a22d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -627,9 +627,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
@@ -638,10 +640,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
assert isinstance(colocate_with, list)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): In eager mode, use one thread per device.
updates = {}
for d in colocate_with:
@@ -649,7 +653,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates[d] = fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with dist.scope():
ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
- update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+ update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 23b220f64b..f525919048 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
else:
assert False
- def _update(self, var, fn, *args, **kwargs):
- with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
del colocate_with
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 1125d027f6..6ddd91507b 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
if isinstance(var, values.AggregatingVariable):
var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
- return fn(var, *self._select_single_value(args),
- **self._select_single_value(kwargs))
+ result = fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
# TODO(yuefengz): does it need to call _select_single_value?
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(
colocate_with.device), distribute_lib.UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 12789e0bc9..353d11a583 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 5d498fb629..fd280f5754 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1b555482d3..c3c7df3cd8 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -297,6 +297,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -398,11 +399,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
- def _update(self, var, fn, *args, **kwargs):
- # TODO(jhseu): Consider supporting grouped==False.
+ def _update(self, var, options, fn, *args, **kwargs):
assert isinstance(var, values.TPUMirroredVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
+
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
- return fn(var, *args, **kwargs)
+ if should_group:
+ return fn(var, *args, **kwargs)
+ else:
+ return [fn(var, *args, **kwargs)]
# Otherwise, we revert to MirroredStrategy behavior and update each variable
# directly.
@@ -414,23 +420,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, should_group)
- # Make a single control dependency to keep the variables mirrored. If one
- # assignment is fetched, then run all assignments.
- sorted_keys = sorted(updates.keys())
- update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
- for i, d in enumerate(sorted_keys):
- updates[d] = update_tuple[i]
- return values.regroup(updates, values.Mirrored)
+ # TODO(josh11b): Need to implement _update_non_slot()!
def read_var(self, var):
assert isinstance(var, values.TPUMirroredVariable)
return var.read_value()
- def _unwrap(self, value):
- if isinstance(value, list):
- return value
- return [value]
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ return [val.get(device=d) for d in sorted(val.devices)]
+ elif isinstance(val, list):
+ # TODO(josh11b): We need to remove this case; per device values should
+ # be represented using a PerDevice wrapper instead of a list with
+ # one entry per device.
+ return val
+ return [val]
+
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index c18faeb67d..18ceba42c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 6af59dcfbf..53e27c08c4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
@@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Use the processors to update the variables.
update_ops = []
for grad, var in grads_and_vars:
- update_ops.extend(distribution.unwrap(distribution.update(
- var, update, grad)))
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
# Give the child class a chance to do something after applying
# gradients
@@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer):
update_ops = control_flow_ops.group(update_ops)
with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(non_slot_devices, finish)
- if finish_updates is None:
- finish_updates = update_ops
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
# Update `global_step` (if any).
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
-
- def update_global_step(global_step):
- if isinstance(global_step, resource_variable_ops.ResourceVariable):
- return global_step.assign_add(
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- read_value=False)
- else:
- return state_ops.assign_add(global_step, 1)
-
- apply_updates = distribution.group(
- distribution.update(global_step, update_global_step), name=name)
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(
+ global_step, update_global_step, name)
# Add the training op to the TRAIN_OP graph collection in graph mode.
if not eager_execution:
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..a92a1bdee7 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -807,15 +806,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +831,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 998b5c35ce..ce580a406f 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -89,6 +89,7 @@ def get_tower_context():
"""Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context (this function
will return the default TowerContext object);
2. switches to cross-tower context (in which case this will return
@@ -121,6 +122,7 @@ def get_cross_tower_context():
"""Returns the current DistributionStrategy if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context;
2. switches to cross-tower context when entering a
`with DistributionStrategy.scope():` block;
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 30b0ed20c8..47034919e1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -692,7 +692,7 @@ class Optimizer(
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.unwrap(distribution.update(var, update, grad))
+ for op in distribution.update(var, update, grad, grouped=False)
]
def finish(self, update_ops):
@@ -700,13 +700,13 @@ class Optimizer(
non_slot_devices = distribution.non_slot_devices(var_list)
finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops)
+ non_slot_devices, finish, self, update_ops, grouped=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
- apply_updates = distribution.group(distribution.update(
- global_step, state_ops.assign_add, 1, name=name))
+ with ops.control_dependencies(finish_updates):
+ apply_updates = distribution.update(
+ global_step, state_ops.assign_add, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):