aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-06-29 18:02:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 18:05:25 -0700
commitc290930ec1beacbcac414b43b3367dd44ffbd303 (patch)
treeb1136e7c32718a6f1f9ebfde3073c88546078de6 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parenta520735d205ca5678fc8a371ea1add00413907fe (diff)
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation methods strings in distribution strategy to the enum values. Update Mirrored strategy to use these parameters to decide on whether a variable should be Mirrored or TowerLocal. Update different distribution strategy value types to use the `VariableAggregation` Enum PiperOrigin-RevId: 202736077
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py47
1 files changed, 35 insertions, 12 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d269bed1e5..14c02ab1ad 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -104,9 +104,32 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- tower_local = kwargs.pop("tower_local_reduce_method", None)
- if tower_local is not None:
+ # Get synchronization value
+ synchronization = kwargs.get(
+ "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [a for a in variable_scope.VariableAggregation]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
@@ -139,11 +162,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert not isinstance(v, values.DistributedVariable)
index[d] = v
- if tower_local is None:
- result = values.MirroredVariable(index, index[devices[0]])
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]],
+ aggregation)
else:
- result = values.TowerLocalVariable(
- index, index[devices[0]], tower_local)
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
if not context.executing_eagerly():
g = ops.get_default_graph()
@@ -308,12 +331,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
return self._cross_tower_ops
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.PerDevice):
if value == 0:
return 0
- if method_string == "mean":
+ if aggregation == variable_scope.VariableAggregation.MEAN:
return self._broadcast(value, destinations)
cross_tower_ops_lib.validate_destinations(destinations)
@@ -331,13 +354,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
value_updates[d] = array_ops.identity(value)
return values.Mirrored(value_updates)
raise ValueError("A non PerDevice value cannot be reduced with the given "
- "method_string.")
+ "aggregation.")
return self._get_cross_tower_ops().reduce(
- method_string, value, destinations=destinations)
+ aggregation, value, destinations=destinations)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return self._get_cross_tower_ops().batch_reduce(method_string,
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):