From c290930ec1beacbcac414b43b3367dd44ffbd303 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Fri, 29 Jun 2018 18:02:18 -0700 Subject: Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation methods strings in distribution strategy to the enum values. Update Mirrored strategy to use these parameters to decide on whether a variable should be Mirrored or TowerLocal. Update different distribution strategy value types to use the `VariableAggregation` Enum PiperOrigin-RevId: 202736077 --- .../contrib/distribute/python/mirrored_strategy.py | 47 ++++++++++++++++------ 1 file changed, 35 insertions(+), 12 deletions(-) (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py') diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d269bed1e5..14c02ab1ad 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -104,9 +104,32 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - tower_local = kwargs.pop("tower_local_reduce_method", None) - if tower_local is not None: + # Get synchronization value + synchronization = kwargs.get( + "synchronization", variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [a for a in variable_scope.VariableAggregation]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) # Ignore user-specified caching device, not needed for mirrored variables. kwargs.pop("caching_device", None) @@ -139,11 +162,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert not isinstance(v, values.DistributedVariable) index[d] = v - if tower_local is None: - result = values.MirroredVariable(index, index[devices[0]]) + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], + aggregation) else: - result = values.TowerLocalVariable( - index, index[devices[0]], tower_local) + result = values.MirroredVariable(index, index[devices[0]], aggregation) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -308,12 +331,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) return self._cross_tower_ops - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.PerDevice): if value == 0: return 0 - if method_string == "mean": + if aggregation == variable_scope.VariableAggregation.MEAN: return self._broadcast(value, destinations) cross_tower_ops_lib.validate_destinations(destinations) @@ -331,13 +354,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): value_updates[d] = array_ops.identity(value) return values.Mirrored(value_updates) raise ValueError("A non PerDevice value cannot be reduced with the given " - "method_string.") + "aggregation.") return self._get_cross_tower_ops().reduce( - method_string, value, destinations=destinations) + aggregation, value, destinations=destinations) - def _batch_reduce(self, method_string, value_destination_pairs): - return self._get_cross_tower_ops().batch_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) def _update(self, var, fn, *args, **kwargs): -- cgit v1.2.3