Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r--  tensorflow/python/training/distribute.py  104
1 file changed, 30 insertions(+), 74 deletions(-)
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 6a326b65bb..c719045c7f 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -221,11 +221,11 @@ def has_distribution_strategy():
def get_loss_reduction():
- """Reduce `method_string` corresponding to the last loss reduction."""
+ """Reduce `aggregation` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if loss_reduction == losses_impl.Reduction.SUM:
- return "sum"
- return "mean"
+ return variable_scope.VariableAggregation.SUM
+ return variable_scope.VariableAggregation.MEAN
# ------------------------------------------------------------------------------
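For context, `get_loss_reduction()` now reports a member of the `tf.VariableAggregation` enum rather than the old "sum"/"mean" strings. A minimal caller-side sketch (illustrative only, not part of this diff; assumes this module's `get_loss_reduction` is in scope and a loss has already been built):

    from tensorflow.python.ops import variable_scope

    # Callers now compare against the enum rather than string literals.
    aggregation = get_loss_reduction()
    assert aggregation in (variable_scope.VariableAggregation.SUM,
                           variable_scope.VariableAggregation.MEAN)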
@@ -539,8 +539,8 @@ class DistributionStrategy(object):
1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator.
2. Define each tower `d.call_for_each_tower()` up to the point of
getting a list of gradient, variable pairs.
- 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the
- gradients (with locality T) into values with locality V(`v`).
+ 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum
+ the gradients (with locality T) into values with locality V(`v`).
4. Call `d.update(v)` for each variable to update its value.
Steps 3 and 4 are done automatically by class `Optimizer` if you call
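For illustration, a minimal sketch of steps 3 and 4 under the enum-based API introduced here, assuming a strategy instance `d`, a per-tower gradient `g`, and a mirrored variable `v` (placeholder names, not defined in this diff):

    from tensorflow.python.ops import variable_scope

    # Step 3: sum the per-tower gradient into a value colocated with `v`.
    summed_grad = d.reduce(
        variable_scope.VariableAggregation.SUM, g, destinations=v)

    # Step 4: run the update on every device that holds a copy of `v`;
    # `update()` calls the passed fn with the local copy of the variable.
    update_op = d.update(v, lambda var, grad: var.assign_sub(grad), summed_grad)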
@@ -614,43 +614,6 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, reduce_method):
- """Inside this scope, new variables will not be mirrored.
-
- There will still be one component variable per tower, but there is
- no requirement that they stay in sync. Instead, when saving them
- or calling `read_var()`, we use the value that results when
- calling `reduce()` on all the towers' variables.
-
- Note: tower-local implies not trainable. Instead, it is expected
- that each tower will directly update (using `assign_add()` or
- whatever) its local variable instance but only the aggregated
- value (accessible using `read_var()`) will be exported from the
- model. When it is acceptable to only aggregate on export, we
- greatly reduce communication overhead by using tower-local
- variables.
-
- Note: All component variables will be initialized to the same
- value, using the initialization expression from the first tower.
- The values will match even if the initialization expression uses
- random numbers.
-
- Args:
- reduce_method: String used as a `method_string` to `reduce()`
- to get the value to save when checkpointing.
-
- Returns:
- A context manager.
- """
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["use_resource"] = True
- kwargs["tower_local_reduce_method"] = reduce_method
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def read_var(self, v):
"""Reads the value of a variable.
@@ -816,12 +779,12 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, method_string, value, destinations=None):
+ def reduce(self, aggregation, value, destinations=None):
"""Combine (via e.g. sum or mean) values across towers.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how the values will be combined. Accepted values
+ are @{tf.VariableAggregation.SUM} and @{tf.VariableAggregation.MEAN}.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -836,18 +799,21 @@ class DistributionStrategy(object):
# TODO(josh11b): Return an unwrapped value if colocate_with is a
# single device.
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._reduce(method_string, value, destinations)
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._reduce(aggregation, value, destinations)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
raise NotImplementedError("must be implemented in descendants")
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Combine multiple `reduce` calls into one for faster execution.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how the values will be combined. Accepted values
+ are @{tf.VariableAggregation.SUM} and @{tf.VariableAggregation.MEAN}.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -856,12 +822,17 @@ class DistributionStrategy(object):
"""
# TODO(josh11b): More docstring
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._batch_reduce(method_string, value_destination_pairs)
-
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self.reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._batch_reduce(aggregation, value_destination_pairs)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self.reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
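A hedged usage sketch of the batched form, assuming a strategy `d` and equal-length lists `grads` and `variables` (placeholder names):

    from tensorflow.python.ops import variable_scope

    # One batched call replaces several reduce() calls; each gradient is
    # summed into a value with the locality of its paired variable.
    reduced_grads = d.batch_reduce(
        variable_scope.VariableAggregation.SUM,
        list(zip(grads, variables)))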
def update(self, var, fn, *args, **kwargs):
"""Run `fn` to update `var` using inputs mirrored to the same devices.
@@ -1090,10 +1061,6 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, reduce_method):
- """Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(reduce_method)
-
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
@@ -1140,22 +1107,11 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, reduce_method):
- """Does not set to resource variables."""
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["trainable"] = False
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
_require_distribution_strategy_scope(self)
@@ -1176,9 +1132,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
with TowerContext(self, tower_id=0):
return fn(*args, **kwargs)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
# TODO(josh11b): Use destinations?
- del method_string, destinations
+ del aggregation, destinations
return value
def _update(self, var, fn, *args, **kwargs):