 tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py | 3
 tensorflow/contrib/distribute/python/mirrored_strategy.py | 12
 tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 2
 tensorflow/contrib/distribute/python/one_device_strategy.py | 17
 tensorflow/contrib/distribute/python/parameter_server_strategy.py | 22
 tensorflow/contrib/distribute/python/parameter_server_strategy_test.py | 3
 tensorflow/contrib/distribute/python/strategy_test_lib.py | 6
 tensorflow/contrib/distribute/python/tpu_strategy.py | 36
 tensorflow/contrib/distribute/python/values.py | 36
 tensorflow/contrib/optimizer_v2/optimizer_v2.py | 32
 tensorflow/python/training/distribute.py | 51
 tensorflow/python/training/distribution_strategy_context.py | 2
 tensorflow/python/training/optimizer.py | 10
 13 files changed, 144 insertions(+), 88 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 33ffbf6abe..6796a23d46 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 4d7516063c..6bd380a22d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -627,9 +627,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
@@ -638,10 +640,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
assert isinstance(colocate_with, list)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): In eager mode, use one thread per device.
updates = {}
for d in colocate_with:
@@ -649,7 +653,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates[d] = fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with dist.scope():
ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
- update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+ update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 23b220f64b..f525919048 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
else:
assert False
- def _update(self, var, fn, *args, **kwargs):
- with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
del colocate_with
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 1125d027f6..6ddd91507b 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
if isinstance(var, values.AggregatingVariable):
var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
- return fn(var, *self._select_single_value(args),
- **self._select_single_value(kwargs))
+ result = fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
# TODO(yuefengz): does it need to call _select_single_value?
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(
colocate_with.device), distribute_lib.UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
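
With `grouped=False`, the structure returned by `fn()` is mapped leaf-by-leaf through `_unwrap`, so every leaf becomes a per-device list in a deterministic device order. A toy stand-in for that mapping (plain Python; dicts play the role of `DistributedValues` and a small helper stands in for `nest.map_structure`):

def _unwrap(val):
    if isinstance(val, dict):                 # stand-in for DistributedValues
        return [val[d] for d in sorted(val)]  # deterministic device order
    return [val]                              # a single-device value

def map_structure(fn, structure):
    # Minimal stand-in for nest.map_structure over lists/tuples/leaves.
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, s) for s in structure)
    return fn(structure)

result = ("assign_op", {"/gpu:0": "op_a", "/gpu:1": "op_b"})
print(map_structure(_unwrap, result))
# -> (['assign_op'], ['op_a', 'op_b'])
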
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 12789e0bc9..353d11a583 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 5d498fb629..fd280f5754 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1b555482d3..c3c7df3cd8 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -297,6 +297,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -398,11 +399,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
- def _update(self, var, fn, *args, **kwargs):
- # TODO(jhseu): Consider supporting grouped==False.
+ def _update(self, var, options, fn, *args, **kwargs):
assert isinstance(var, values.TPUMirroredVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
+
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
- return fn(var, *args, **kwargs)
+ if should_group:
+ return fn(var, *args, **kwargs)
+ else:
+ return [fn(var, *args, **kwargs)]
# Otherwise, we revert to MirroredStrategy behavior and update each variable
# directly.
@@ -414,23 +420,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, should_group)
- # Make a single control dependency to keep the variables mirrored. If one
- # assignment is fetched, then run all assignments.
- sorted_keys = sorted(updates.keys())
- update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
- for i, d in enumerate(sorted_keys):
- updates[d] = update_tuple[i]
- return values.regroup(updates, values.Mirrored)
+ # TODO(josh11b): Need to implement _update_non_slot()!
def read_var(self, var):
assert isinstance(var, values.TPUMirroredVariable)
return var.read_value()
- def _unwrap(self, value):
- if isinstance(value, list):
- return value
- return [value]
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ return [val.get(device=d) for d in sorted(val.devices)]
+ elif isinstance(val, list):
+ # TODO(josh11b): We need to remove this case; per device values should
+ # be represented using a PerDevice wrapper instead of a list with
+ # one entry per device.
+ return val
+ return [val]
+
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index c18faeb67d..18ceba42c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 6af59dcfbf..53e27c08c4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
@@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Use the processors to update the variables.
update_ops = []
for grad, var in grads_and_vars:
- update_ops.extend(distribution.unwrap(distribution.update(
- var, update, grad)))
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
# Give the child class a chance to do something after applying
# gradients
@@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer):
update_ops = control_flow_ops.group(update_ops)
with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(non_slot_devices, finish)
- if finish_updates is None:
- finish_updates = update_ops
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
# Update `global_step` (if any).
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
-
- def update_global_step(global_step):
- if isinstance(global_step, resource_variable_ops.ResourceVariable):
- return global_step.assign_add(
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- read_value=False)
- else:
- return state_ops.assign_add(global_step, 1)
-
- apply_updates = distribution.group(
- distribution.update(global_step, update_global_step), name=name)
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(
+ global_step, update_global_step, name)
# Add the training op to the TRAIN_OP graph collection in graph mode.
if not eager_execution:
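
The `[None]` check above exists because `grouped=False` makes `update_non_slot()` always return a list with one element per device; when `finish()` returns `None` (the common case), that list is `[None]`, which is useless as a control dependency, so it is swapped for the already-grouped `update_ops`. In toy form:

def pick_finish_updates(finish_updates, update_ops):
    # finish_updates came from update_non_slot(..., grouped=False), so it is
    # always a list; [None] means finish() had no ops of its own to add.
    return [update_ops] if finish_updates == [None] else finish_updates

print(pick_finish_updates([None], "grouped_update_ops"))            # ['grouped_update_ops']
print(pick_finish_updates(["op_a", "op_b"], "grouped_update_ops"))  # ['op_a', 'op_b']
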
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..a92a1bdee7 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -807,15 +806,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +831,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
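
From the caller's perspective, the docstring change above amounts to this: the old `d.unwrap(d.update(...))` idiom collapses into `d.update(..., grouped=False)`, which returns the per-tower list directly. A toy (non-TensorFlow) strategy showing that the two spellings agree:

class FakeStrategy(object):
    """Toy stand-in for a DistributionStrategy with two towers."""

    def __init__(self, devices):
        self.devices = devices

    def update(self, var, fn, *args, **kwargs):
        grouped = kwargs.pop("grouped", True)
        per_device = {d: fn(var, *args, **kwargs) for d in self.devices}
        if grouped:
            return per_device           # stand-in for a Mirrored value
        return [per_device[d] for d in sorted(per_device)]

    def unwrap(self, value):
        if isinstance(value, dict):
            return [value[d] for d in sorted(value)]
        return [value]

d = FakeStrategy(["/gpu:0", "/gpu:1"])
old_style = d.unwrap(d.update(10, lambda v, g: v + g, 5))
new_style = d.update(10, lambda v, g: v + g, 5, grouped=False)
assert old_style == new_style == [15, 15]
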
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 998b5c35ce..ce580a406f 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -89,6 +89,7 @@ def get_tower_context():
"""Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context (this function
will return the default TowerContext object);
2. switches to cross-tower context (in which case this will return
@@ -121,6 +122,7 @@ def get_cross_tower_context():
"""Returns the current DistributionStrategy if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context;
2. switches to cross-tower context when entering a
`with DistributionStrategy.scope():` block;
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 30b0ed20c8..47034919e1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -692,7 +692,7 @@ class Optimizer(
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.unwrap(distribution.update(var, update, grad))
+ for op in distribution.update(var, update, grad, grouped=False)
]
def finish(self, update_ops):
@@ -700,13 +700,13 @@ class Optimizer(
non_slot_devices = distribution.non_slot_devices(var_list)
finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops)
+ non_slot_devices, finish, self, update_ops, grouped=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
- apply_updates = distribution.group(distribution.update(
- global_step, state_ops.assign_add, 1, name=name))
+ with ops.control_dependencies(finish_updates):
+ apply_updates = distribution.update(
+ global_step, state_ops.assign_add, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):