diff options
13 files changed, 144 insertions, 88 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 33ffbf6abe..6796a23d46 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 4d7516063c..6bd380a22d 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -627,9 +627,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. assert isinstance(var, values.DistributedVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) @@ -638,10 +640,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): assert isinstance(colocate_with, list) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. # TODO(josh11b): In eager mode, use one thread per device. updates = {} for d in colocate_with: @@ -649,7 +653,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): updates[d] = fn(*values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index f51e543624..eeac528329 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with dist.scope(): ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) - update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) # Initialize variables. self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 23b220f64b..f525919048 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): else: assert False - def _update(self, var, fn, *args, **kwargs): - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(var, *args, **kwargs) + def _update(self, var, options, fn, *args, **kwargs): + # The implementations of _update() and _update_non_slot() are identical + # except _update() passes `var` as the first argument to `fn()`. + return self._update_non_slot(var, options, fn, var, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): del colocate_with + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 1125d027f6..6ddd91507b 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - return fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) + result = fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def _unwrap(self, val): if isinstance(val, values.DistributedValues): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 12789e0bc9..353d11a583 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 5d498fb629..fd280f5754 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list @@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 1b555482d3..c3c7df3cd8 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -297,6 +297,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # For outputs that have already been aggregated, take the first value # from the list as each value should be the same. Else return the full # list of values. + # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. if aggregation is not variables_lib.VariableAggregation.NONE: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] @@ -398,11 +399,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return output * (1. / len(value)) return output - def _update(self, var, fn, *args, **kwargs): - # TODO(jhseu): Consider supporting grouped==False. + def _update(self, var, options, fn, *args, **kwargs): assert isinstance(var, values.TPUMirroredVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. + if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - return fn(var, *args, **kwargs) + if should_group: + return fn(var, *args, **kwargs) + else: + return [fn(var, *args, **kwargs)] # Otherwise, we revert to MirroredStrategy behavior and update each variable # directly. @@ -414,23 +420,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) + return values.update_regroup(self, updates, should_group) - # Make a single control dependency to keep the variables mirrored. If one - # assignment is fetched, then run all assignments. - sorted_keys = sorted(updates.keys()) - update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys]) - for i, d in enumerate(sorted_keys): - updates[d] = update_tuple[i] - return values.regroup(updates, values.Mirrored) + # TODO(josh11b): Need to implement _update_non_slot()! def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) return var.read_value() - def _unwrap(self, value): - if isinstance(value, list): - return value - return [value] + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + return [val.get(device=d) for d in sorted(val.devices)] + elif isinstance(val, list): + # TODO(josh11b): We need to remove this case; per device values should + # be represented using a PerDevice wrapper instead of a list with + # one entry per device. + return val + return [val] + @property def num_towers(self): diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index c18faeb67d..18ceba42c2 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. strategy = distribution_strategy_context.get_distribution_strategy() - updates = strategy.update(self, f, *args, **kwargs) - grouped = strategy.group(updates) - if isinstance(updates, DistributedValues) and updates.is_tensor_like: - # Make sure we run all updates. Without this, something like - # session.run(mirrored_var.assign*(...)) may only update one tower. - index = {} - for d in updates.devices: - with ops.device(d), ops.control_dependencies([grouped]): - index[d] = array_ops.identity(updates.get(d)) - return Mirrored(index) - else: - return grouped + return strategy.update(self, f, *args, **kwargs) else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured): return nest.map_structure(_get_mirrored, structured) +def update_regroup(strategy, updates, should_group): + """Regroup for an update, with dependencies to ensure all updates execute.""" + regrouped = regroup(updates, Mirrored) + if not should_group: + return nest.map_structure(strategy.unwrap, regrouped) + grouped_flat = [] + for u in nest.flatten(regrouped): + if isinstance(u, DistributedValues): + g = strategy.group(u) + if u.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(strategy.update(...)) may only update one tower. + index = {} + for d in u.devices: + with ops.device(d), ops.control_dependencies([g]): + index[d] = array_ops.identity(u.get(d)) + g = Mirrored(index) + else: + g = u + grouped_flat.append(g) + return nest.pack_sequence_as(regrouped, grouped_flat) + + class PerDeviceDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 6af59dcfbf..53e27c08c4 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Use the processors to update the variables. update_ops = [] for grad, var in grads_and_vars: - update_ops.extend(distribution.unwrap(distribution.update( - var, update, grad))) + update_ops.extend(distribution.update(var, update, grad, grouped=False)) # Give the child class a chance to do something after applying # gradients @@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer): update_ops = control_flow_ops.group(update_ops) with ops.control_dependencies([update_ops]): - finish_updates = distribution.update_non_slot(non_slot_devices, finish) - if finish_updates is None: - finish_updates = update_ops + finish_updates = distribution.update_non_slot( + non_slot_devices, finish, grouped=False) + # We said grouped=False, which means finish_updates is always a list. + # It will be [None] when finish() returns None. + if finish_updates == [None]: + finish_updates = [update_ops] # Update `global_step` (if any). if global_step is None: apply_updates = distribution.group(finish_updates, name=name) else: - with ops.control_dependencies(distribution.unwrap(finish_updates)): - - def update_global_step(global_step): - if isinstance(global_step, resource_variable_ops.ResourceVariable): - return global_step.assign_add( - ops.convert_to_tensor(1, dtype=global_step.dtype), - read_value=False) - else: - return state_ops.assign_add(global_step, 1) - - apply_updates = distribution.group( - distribution.update(global_step, update_global_step), name=name) + with ops.control_dependencies(finish_updates): + + def update_global_step(global_step, name): + return global_step.assign_add(1, read_value=False, name=name) + + apply_updates = distribution.update( + global_step, update_global_step, name) # Add the training op to the TRAIN_OP graph collection in graph mode. if not eager_execution: diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 419a9ec12b..a92a1bdee7 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging @@ -807,15 +806,22 @@ class DistributionStrategy(object): var: Variable, possibly mirrored to multiple devices, to operate on. fn: Function to call. Should take the variable as the first argument. *args: Additional positional arguments to pass to `fn()`. - **kwargs: Keyword arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is + specified, the return value will be unwrapped. Returns: - Merged return value of `fn` across all towers. + By default, the merged return value of `fn` across all towers. The merged + result has dependencies to make sure that if it is evaluated at all, the + side effects (updates) will happen on every tower. If instead + "grouped=False" is specified, this function will return a nest of lists + where each list has an element per tower, and the caller is responsible + for ensuring all elements are executed. """ _require_cross_tower_context(self) - return self._update(var, fn, *args, **kwargs) + options = {"grouped": kwargs.pop("grouped", True)} + return self._update(var, options, fn, *args, **kwargs) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") def update_non_slot(self, colocate_with, fn, *args, **kwargs): @@ -825,15 +831,18 @@ class DistributionStrategy(object): colocate_with: The return value of `non_slot_devices()`. fn: Function to execute. *args: Positional arguments to pass to `fn()`. - **kwargs: Keyword arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is + specified, the return value will be unwrapped and the caller is + responsible for ensuring all elements are executed. Returns: Return value of `fn`, possibly merged across devices. """ _require_cross_tower_context(self) - return self._update_non_slot(colocate_with, fn, *args, **kwargs) + options = {"grouped": kwargs.pop("grouped", True)} + return self._update_non_slot(colocate_with, options, fn, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") def unwrap(self, value): @@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy): del aggregation, destinations return value - def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Figure out what we should be passing to UpdateContext() - # once that value is used for something. - with ops.colocate_with(var), UpdateContext(var): - return fn(var, *args, **kwargs) + def _update(self, var, options, fn, *args, **kwargs): + # The implementations of _update() and _update_non_slot() are identical + # except _update() passes `var` as the first argument to `fn()`. + return self._update_non_slot(var, options, fn, var, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. # TODO(josh11b): Figure out what we should be passing to UpdateContext() # once that value is used for something. with ops.colocate_with(colocate_with), UpdateContext(colocate_with): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def read_var(self, tower_local_var): return array_ops.identity(tower_local_var) @@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy): def increment_var(v, amount=1): """`v += amount`, distributed-aware version.""" def update(vu): - if isinstance(vu, resource_variable_ops.ResourceVariable): - return vu.assign_add(amount, read_value=False) - else: - return state_ops.assign_add(vu, amount) + return vu.assign_add(amount, read_value=False) def merge_fn(dist, vm): - return dist.group(dist.update(vm, update)) + return dist.update(vm, update) tower_context = distribution_strategy_context.get_tower_context() return tower_context.merge_call(merge_fn, v) diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py index 998b5c35ce..ce580a406f 100644 --- a/tensorflow/python/training/distribution_strategy_context.py +++ b/tensorflow/python/training/distribution_strategy_context.py @@ -89,6 +89,7 @@ def get_tower_context(): """Returns the current TowerContext or None if in a cross-tower context. Note that execution: + 1. starts in the default (single-tower) tower context (this function will return the default TowerContext object); 2. switches to cross-tower context (in which case this will return @@ -121,6 +122,7 @@ def get_cross_tower_context(): """Returns the current DistributionStrategy if in a cross-tower context. Note that execution: + 1. starts in the default (single-tower) tower context; 2. switches to cross-tower context when entering a `with DistributionStrategy.scope():` block; diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 30b0ed20c8..47034919e1 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -692,7 +692,7 @@ class Optimizer( update_ops = [ op for grad, var in grads_and_vars - for op in distribution.unwrap(distribution.update(var, update, grad)) + for op in distribution.update(var, update, grad, grouped=False) ] def finish(self, update_ops): @@ -700,13 +700,13 @@ class Optimizer( non_slot_devices = distribution.non_slot_devices(var_list) finish_updates = distribution.update_non_slot( - non_slot_devices, finish, self, update_ops) + non_slot_devices, finish, self, update_ops, grouped=False) if global_step is None: apply_updates = distribution.group(finish_updates, name=name) else: - with ops.control_dependencies(distribution.unwrap(finish_updates)): - apply_updates = distribution.group(distribution.update( - global_step, state_ops.assign_add, 1, name=name)) + with ops.control_dependencies(finish_updates): + apply_updates = distribution.update( + global_step, state_ops.assign_add, 1, name=name) if not context.executing_eagerly(): if isinstance(apply_updates, ops.Tensor): |