diff options
author | Pavithra Vijay <psv@google.com> | 2018-06-29 18:02:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-29 18:05:25 -0700 |
commit | c290930ec1beacbcac414b43b3367dd44ffbd303 (patch) | |
tree | b1136e7c32718a6f1f9ebfde3073c88546078de6 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | a520735d205ca5678fc8a371ea1add00413907fe (diff) |
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ
Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation methods strings in distribution strategy to the enum values.
Update Mirrored strategy to use these parameters to decide on whether a variable should be Mirrored or TowerLocal.
Update different distribution strategy value types to use the `VariableAggregation` Enum
PiperOrigin-RevId: 202736077
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 47 |
1 files changed, 35 insertions, 12 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d269bed1e5..14c02ab1ad 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -104,9 +104,32 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - tower_local = kwargs.pop("tower_local_reduce_method", None) - if tower_local is not None: + # Get synchronization value + synchronization = kwargs.get( + "synchronization", variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [a for a in variable_scope.VariableAggregation]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) # Ignore user-specified caching device, not needed for mirrored variables. kwargs.pop("caching_device", None) @@ -139,11 +162,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert not isinstance(v, values.DistributedVariable) index[d] = v - if tower_local is None: - result = values.MirroredVariable(index, index[devices[0]]) + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], + aggregation) else: - result = values.TowerLocalVariable( - index, index[devices[0]], tower_local) + result = values.MirroredVariable(index, index[devices[0]], aggregation) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -308,12 +331,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) return self._cross_tower_ops - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.PerDevice): if value == 0: return 0 - if method_string == "mean": + if aggregation == variable_scope.VariableAggregation.MEAN: return self._broadcast(value, destinations) cross_tower_ops_lib.validate_destinations(destinations) @@ -331,13 +354,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): value_updates[d] = array_ops.identity(value) return values.Mirrored(value_updates) raise ValueError("A non PerDevice value cannot be reduced with the given " - "method_string.") + "aggregation.") return self._get_cross_tower_ops().reduce( - method_string, value, destinations=destinations) + aggregation, value, destinations=destinations) - def _batch_reduce(self, method_string, value_destination_pairs): - return self._get_cross_tower_ops().batch_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) def _update(self, var, fn, *args, **kwargs): |