diff options
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- | tensorflow/python/training/distribute.py | 104 |
1 file changed, 30 insertions, 74 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 6a326b65bb..c719045c7f 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -221,11 +221,11 @@ def has_distribution_strategy(): def get_loss_reduction(): - """Reduce `method_string` corresponding to the last loss reduction.""" + """Reduce `aggregation` corresponding to the last loss reduction.""" loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access if loss_reduction == losses_impl.Reduction.SUM: - return "sum" - return "mean" + return variable_scope.VariableAggregation.SUM + return variable_scope.VariableAggregation.MEAN # ------------------------------------------------------------------------------ @@ -539,8 +539,8 @@ class DistributionStrategy(object): 1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator. 2. Define each tower `d.call_for_each_tower()` up to the point of getting a list of gradient, variable pairs. - 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the - gradients (with locality T) into values with locality V(`v`). + 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum + the gradients (with locality T) into values with locality V(`v`). 4. Call `d.update(v)` for each variable to update its value. Steps 3 and 4 are done automatically by class `Optimizer` if you call @@ -614,43 +614,6 @@ class DistributionStrategy(object): # Note: should support "colocate_with" argument. raise NotImplementedError("must be implemented in descendants") - def tower_local_var_scope(self, reduce_method): - """Inside this scope, new variables will not be mirrored. - - There will still be one component variable per tower, but there is - no requirement that they stay in sync. Instead, when saving them - or calling `read_var()`, we use the value that results when - calling `reduce()` on all the towers' variables. 
- - Note: tower-local implies not trainable. Instead, it is expected - that each tower will directly update (using `assign_add()` or - whatever) its local variable instance but only the aggregated - value (accessible using `read_var()`) will be exported from the - model. When it is acceptable to only aggregate on export, we - greatly reduce communication overhead by using tower-local - variables. - - Note: All component variables will be initialized to the same - value, using the initialization expression from the first tower. - The values will match even if the initialization expression uses - random numbers. - - Args: - reduce_method: String used as a `method_string` to `reduce()` - to get the value to save when checkpointing. - - Returns: - A context manager. - """ - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["use_resource"] = True - kwargs["tower_local_reduce_method"] = reduce_method - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def read_var(self, v): """Reads the value of a variable. @@ -816,12 +779,12 @@ class DistributionStrategy(object): def _call_for_each_tower(self, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def reduce(self, method_string, value, destinations=None): + def reduce(self, aggregation, value, destinations=None): """Combine (via e.g. sum or mean) values across towers. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value: A per-device value with one value per tower. destinations: An optional mirrored variable, a device string, list of device strings. 
The return value will be copied to all @@ -836,18 +799,21 @@ class DistributionStrategy(object): # TODO(josh11b): Return an unwrapped value if colocate_with is a # single device. _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._reduce(method_string, value, destinations) + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._reduce(aggregation, value, destinations) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): raise NotImplementedError("must be implemented in descendants") - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Combine multiple `reduce` calls into one for faster execution. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value_destination_pairs: A sequence of (value, destinations) pairs. See `reduce()` for a description. 
@@ -856,12 +822,17 @@ class DistributionStrategy(object): """ # TODO(josh11b): More docstring _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._batch_reduce(method_string, value_destination_pairs) - - def _batch_reduce(self, method_string, value_destination_pairs): - return [self.reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._batch_reduce(aggregation, value_destination_pairs) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self.reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def update(self, var, fn, *args, **kwargs): """Run `fn` to update `var` using inputs mirrored to the same devices. @@ -1090,10 +1061,6 @@ class TowerContext(object): finally: _pop_per_thread_mode() - def tower_local_var_scope(self, reduce_method): - """Alias for distribution_strategy.tower_local_var_scope().""" - return self._distribution_strategy.tower_local_var_scope(reduce_method) - @property def is_single_tower(self): """Returns whether there is a single tower or multiple.""" @@ -1140,22 +1107,11 @@ class _DefaultDistributionStrategy(DistributionStrategy): def creator(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) - kwargs.pop("tower_local_reduce_method", None) return next_creator(*args, **kwargs) return _CurrentDistributionContext( self, variable_scope.variable_creator_scope(creator)) - def tower_local_var_scope(self, reduce_method): - """Does not set to resource variables.""" - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["trainable"] = False - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def 
colocate_vars_with(self, colocate_with_variable): """Does not require `self.scope`.""" _require_distribution_strategy_scope(self) @@ -1176,9 +1132,9 @@ class _DefaultDistributionStrategy(DistributionStrategy): with TowerContext(self, tower_id=0): return fn(*args, **kwargs) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): # TODO(josh11b): Use destinations? - del method_string, destinations + del aggregation, destinations return value def _update(self, var, fn, *args, **kwargs): |