path: root/tensorflow/contrib/distribute/python/cross_tower_ops.py
Diffstat (limited to 'tensorflow/contrib/distribute/python/cross_tower_ops.py')
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops.py  |  71
1 file changed, 39 insertions(+), 32 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 0261ce43fa..b0baf0dad1 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
@@ -88,7 +89,7 @@ def _simple_broadcast(value, destinations):
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
- method_string):
+ aggregation):
# pylint: disable=g-missing-docstring
all_values = []
count = 0
@@ -112,11 +113,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
all_values, accumulation_fn)
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
reduced, count)
- elif method_string != "sum":
- raise ValueError("`method_string` must be 'sum' or 'mean'")
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise ValueError("`aggregation` must be VariableAggregation.SUM "
+ "or VariableAggregation.MEAN.")
return reduced
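For reference, a minimal pure-Python sketch of the enum-based dispatch this hunk introduces (the VariableAggregation class and the reduction body below are illustrative stand-ins, not the TensorFlow implementations):

import enum

class VariableAggregation(enum.Enum):  # stand-in for tf.VariableAggregation
  SUM = 1
  MEAN = 2

def simple_reduce(values, aggregation):
  reduced = sum(values)  # stand-in for accumulation_fn over gathered values
  if aggregation == VariableAggregation.MEAN:
    reduced /= len(values)  # divide_by_n_tensors_or_indexed_slices stand-in
  elif aggregation != VariableAggregation.SUM:
    raise ValueError("`aggregation` must be VariableAggregation.SUM "
                     "or VariableAggregation.MEAN.")
  return reduced

print(simple_reduce([1.0, 2.0, 3.0], VariableAggregation.MEAN))  # 2.0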
@@ -126,14 +128,15 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, method_string, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations=None):
"""Reduce `per_device_value` to `destinations`.
- It runs the reduction operation defined by `method_string` and put the
+ It runs the reduction operation defined by `aggregation` and puts the
result on `destinations`.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
per_device_value: a PerDevice object.
destinations: the reduction destinations.
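For callers, the visible change is the first argument: a tf.VariableAggregation member replaces the old 'sum'/'mean' string. A hedged usage sketch; the PerDevice index-dict construction mirrors this code base's tests and is an assumption, as are the device strings:

import tensorflow as tf
from tensorflow.contrib.distribute.python import cross_tower_ops as ctops
from tensorflow.contrib.distribute.python import values as value_lib

# One tensor per device, wrapped as a PerDevice value (test-style
# construction; not a public API).
per_device = value_lib.PerDevice({
    "/device:CPU:0": tf.constant(1.0),
    "/device:GPU:0": tf.constant(3.0),
})

cross_ops = ctops.ReductionToOneDeviceCrossTowerOps()
# Before this change: cross_ops.reduce("mean", per_device)
mirrored = cross_ops.reduce(tf.VariableAggregation.MEAN, per_device)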
@@ -147,16 +150,17 @@ class CrossTowerOps(object):
raise ValueError("`per_device_value` must be a `PerDevice` object.")
if destinations is not None:
validate_destinations(destinations)
- return self._reduce(method_string, per_device_value, destinations)
+ return self._reduce(aggregation, per_device_value, destinations)
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Reduce PerDevice objects in a batch.
Reduce each first element in `value_destination_pairs` to each second
element which indicates the destinations.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
and destinations. If a destination is None, then the destinations
are set to match the devices of the input PerDevice object.
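The batched form takes the same enum; continuing the sketch above, pairing a PerDevice with None reuses its own devices as destinations (still hedged; accepted destination formats may differ by version):

# Continues the previous sketch.
per_device_b = value_lib.PerDevice({
    "/device:CPU:0": tf.constant(2.0),
    "/device:GPU:0": tf.constant(4.0),
})
reduced_list = cross_ops.batch_reduce(
    tf.VariableAggregation.SUM,
    [(per_device, None),                  # destinations default to its devices
     (per_device_b, ["/device:CPU:0"])])  # explicit destination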
@@ -175,7 +179,7 @@ class CrossTowerOps(object):
if d is not None:
validate_destinations(d)
- return self._batch_reduce(method_string, value_destination_pairs)
+ return self._batch_reduce(aggregation, value_destination_pairs)
def broadcast(self, tensor, destinations):
"""Broadcast the `tensor` to destinations.
@@ -190,11 +194,11 @@ class CrossTowerOps(object):
validate_destinations(destinations)
return self._broadcast(tensor, destinations)
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
raise NotImplementedError(
"_reduce method must be implemented in descendants.")
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
raise NotImplementedError(
"_batch_reduce method must be implemented in descendants.")
@@ -220,16 +224,18 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
self.accumulation_fn = accumulation_fn
super(ReductionToOneDeviceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = self.reduce_to_device or devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- self.accumulation_fn, method_string)
+ self.accumulation_fn, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self._reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _group_value_by_device(per_device_values):
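The strategy of ReductionToOneDeviceCrossTowerOps is simple to state in miniature: reduce everything on one device, then broadcast the result to every destination. A pure-Python sketch with plain dicts standing in for PerDevice and Mirrored values:

def reduction_to_one_device(per_device_values, destinations):
  # The reduction happens on destinations[0] (the default reduce_to_device);
  # actual device placement is elided in this stand-in.
  total = sum(per_device_values.values())
  # Broadcast: every destination receives the same reduced value.
  return {d: total for d in destinations}

print(reduction_to_one_device({"cpu:0": 1.0, "gpu:0": 3.0},
                              ["cpu:0", "gpu:0"]))
# -> {'cpu:0': 4.0, 'gpu:0': 4.0}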
@@ -260,18 +266,19 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
+def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
- Mirrored objects are created if method_string is "mean".
+ Mirrored objects are created if aggregation is VariableAggregation.MEAN.
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
- method_string: "mean" or "sum".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
a list of Mirrored objects.
@@ -279,7 +286,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
index = [{} for _ in range(len(grouped_reduced[0]))]
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
index[i][destinations[d]] = v / len(destinations)
else:
index[i][destinations[d]] = v
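Concretely: an all-reduce leaves the sum on every device, so with two destinations and a summed value of 4.0, MEAN stores 2.0 on each destination while SUM stores 4.0 unchanged. A pure-Python sketch of the ungrouping loop above:

def ungroup(grouped_reduced, destinations, mean=True):
  # grouped_reduced: one list of (value, None) pairs per device, as produced
  # by cross_tower_utils.aggregate_gradients_using* (shape assumed here).
  index = [{} for _ in range(len(grouped_reduced[0]))]
  for d, per_device_reduced in enumerate(grouped_reduced):
    for i, (v, _) in enumerate(per_device_reduced):
      index[i][destinations[d]] = v / len(destinations) if mean else v
  return index

print(ungroup([[(4.0, None)], [(4.0, None)]], ["/gpu:0", "/gpu:1"]))
# -> [{'/gpu:0': 2.0, '/gpu:1': 2.0}]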
@@ -488,13 +495,13 @@ class AllReduceCrossTowerOps(CrossTowerOps):
self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string, [per_device_value])[0]
+ return self._batch_all_reduce(aggregation, [per_device_value])[0]
else:
if contains_indexed_slices:
logging.log_first_n(
@@ -504,16 +511,16 @@ class AllReduceCrossTowerOps(CrossTowerOps):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- math_ops.add_n, method_string)
+ math_ops.add_n, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
all_devices_match = _all_devices_match(value_destination_pairs)
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
value_destination_pairs)
if (all_devices_match and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string,
+ return self._batch_all_reduce(aggregation,
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
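Both _reduce and _batch_reduce use the same dispatch rule: the batched all-reduce fast path requires matching devices, graph mode, and no IndexedSlices; otherwise they fall back to reduce-then-broadcast. A tiny sketch of that predicate:

def choose_path(devices_match, executing_eagerly, has_indexed_slices):
  # Mirrors the conditions guarding self._batch_all_reduce(...) above.
  if devices_match and not executing_eagerly and not has_indexed_slices:
    return "batch_all_reduce"
  return "simple_reduce_then_broadcast"

print(choose_path(True, False, False))  # batch_all_reduce
print(choose_path(True, False, True))   # simple_reduce_then_broadcast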
@@ -521,11 +528,11 @@ class AllReduceCrossTowerOps(CrossTowerOps):
"destinations are different.")
return [
- self._reduce(method_string, t, destinations=v)
+ self._reduce(aggregation, t, destinations=v)
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"batch_all_reduce invoked for batches size = %d with "
@@ -556,7 +563,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
reduced = _unpack_tensors(reduced, tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
- method_string)
+ aggregation)
AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
@@ -635,7 +642,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
validate_and_complete_spec(spec) for spec in all_reduce_spec
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"distributed batch_all_reduce invoked for batches size = %d with "
@@ -682,7 +689,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
assert not remaining_grads
return _ungroup_and_make_mirrored(aggregated_grads, destinations,
- method_string)
+ aggregation)
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
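As a usage note, the all-reduce variants are constructed directly and handed to a distribution strategy. A hedged construction sketch; the constructor arguments reflect this era of the code base and may differ in other versions:

from tensorflow.contrib.distribute.python import cross_tower_ops as ctops

# Single-machine all-reduce; "nccl" and "hierarchical_copy" are the
# algorithms this file supports (argument names assumed from this version).
all_reduce = ctops.AllReduceCrossTowerOps(
    all_reduce_alg="hierarchical_copy", num_packs=1)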