diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/cross_tower_ops.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/cross_tower_ops.py | 69 |
1 files changed, 58 insertions, 11 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index 3a7addf221..2a653b0f10 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -53,7 +53,7 @@ def validate_destinations(destinations): if not isinstance( destinations, (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, - six.string_types, list)): + value_lib.AggregatingVariable, six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," " a tf.Variable object, a device string, a list of device " "strings or None") @@ -62,7 +62,44 @@ def validate_destinations(destinations): raise ValueError("destinations can not be empty") +def _make_tensor_into_per_device(input_tensor): + """Converts a single tensor into a PerDevice object.""" + if isinstance(input_tensor, (tuple, list)): + raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, " + "got %r but expected a object that is not a tuple or list." + % (input_tensor,)) + if isinstance(input_tensor, value_lib.PerDevice): + return input_tensor + + try: + device = input_tensor.device + except AttributeError: + raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object " + "because it doesn't have device set.") + + return value_lib.PerDevice({device: input_tensor}) + + +def _normalize_value_destination_pairs(value_destination_pairs): + """Converts each tensor into a PerDevice object in the input list.""" + result = [] + if not isinstance(value_destination_pairs, (list, tuple)): + raise ValueError("`value_destination_pairs` should be a list or tuple") + for pair in value_destination_pairs: + if not isinstance(pair, tuple): + raise ValueError( + "Each element of `value_destination_pairs` should be a tuple.") + if len(pair) != 2: + raise ValueError("Each element of `value_destination_pairs` should be a " + "tuple of size 2.") + + per_device = _make_tensor_into_per_device(pair[0]) + result.append((per_device, pair[1])) + return result + + def _validate_value_destination_pairs(value_destination_pairs): + # TODO(yuefengz): raise exceptions instead of returning False. # pylint: disable=g-missing-docstring if not value_destination_pairs: return False if not isinstance(value_destination_pairs, (list, tuple)): return False @@ -78,12 +115,15 @@ def _validate_value_destination_pairs(value_destination_pairs): def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) - elif isinstance(destinations, resource_variable_ops.ResourceVariable): + elif isinstance(destinations, (resource_variable_ops.ResourceVariable, + value_lib.AggregatingVariable)): return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] - else: + elif isinstance(destinations, (list, tuple)): return [device_util.resolve(destination) for destination in destinations] + else: + return [destinations.device] def _devices_match(left, right): @@ -158,7 +198,7 @@ class CrossTowerOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - per_device_value: a PerDevice object. + per_device_value: a PerDevice object or a tensor with device set. destinations: the reduction destinations. Returns: @@ -168,7 +208,8 @@ class CrossTowerOps(object): ValueError: if per_device_value is not a PerDevice object. """ if not isinstance(per_device_value, value_lib.PerDevice): - raise ValueError("`per_device_value` must be a `PerDevice` object.") + per_device_value = _make_tensor_into_per_device(per_device_value) + if destinations is not None: validate_destinations(destinations) return self._reduce(aggregation, per_device_value, destinations) @@ -183,8 +224,9 @@ class CrossTowerOps(object): aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value_destination_pairs: a list or a tuple of tuples of PerDevice objects - and destinations. If a destination is None, then the destinations - are set to match the devices of the input PerDevice object. + (or tensors with device set if there is one tower) and destinations. If + a destination is None, then the destinations are set to match the + devices of the input PerDevice object. Returns: a list of Mirrored objects. @@ -194,8 +236,11 @@ class CrossTowerOps(object): tuples of PerDevice objects and destinations """ if not _validate_value_destination_pairs(value_destination_pairs): - raise ValueError("`value_destination_pairs` must be a list or a tuple of " - "tuples of PerDevice objects and destinations") + # If the first element of each pair is a tensor, we try to turn it into a + # PerDevice object. + value_destination_pairs = _normalize_value_destination_pairs( + value_destination_pairs) + for _, d in value_destination_pairs: if d is not None: validate_destinations(d) @@ -756,7 +801,7 @@ class CollectiveAllReduce(CrossTowerOps): ) super(CollectiveAllReduce, self).__init__() - # TODO(yuefengz, tucker): is index slices supported by collective ops? + # TODO(yuefengz, tucker): is indexed slices supported by collective ops? def _reduce(self, aggregation, per_device_value, destinations): all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] if destinations is None or _devices_match(per_device_value, destinations): @@ -768,8 +813,10 @@ class CollectiveAllReduce(CrossTowerOps): if d in all_reduced._index: index[d] = all_reduced._index[d] else: - with ops.device(d): + with ops.control_dependencies(list( + all_reduced._index.values())), ops.device(d): index[d] = array_ops.identity(list(all_reduced._index.values())[0]) + return value_lib.Mirrored(index) def _batch_reduce(self, aggregation, value_destination_pairs): |