diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-20 20:49:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 20:53:29 -0700 |
commit | 7de22654844a41575a30cd1ce3a522abd0516fde (patch) | |
tree | 8e5cf95bb8404ed9f05bc510c9f3f25dda9653f5 /tensorflow/contrib/distribute/python/values.py | |
parent | 3f78692cd7b47e9276e8809dd891578758f2de13 (diff) |
Correctly apply the aggregation mode configured on variables in
ParameterServerStrategy when more than one device per machine is used.
To do this, the variable instances returned in that case are wrapped in
a class (AggregatingVariable) that intercepts assign(), assign_add() and
assign_sub() method calls.
PiperOrigin-RevId: 209533673
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 180 |
1 files changed, 157 insertions, 23 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 8548a86421..a58bb3a849 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -308,26 +308,6 @@ class DistributedVariable(DistributedDelegate): ops.register_dense_tensor_like_type(DistributedVariable) -def _get_update_device(): - """Validate we are in update/update_non_slot() and return current device. - - This is used in MirroredVariable.assign* members, to make sure they - are only called via an update method, to make sure all components of the - variable are being updated in a consistent way. - - Returns: - A string device. - - Raises: - RuntimeError: If not in distribution.update()/.update_non_slot(). - """ - device = distribute_lib.get_update_device() - if device is None: - raise RuntimeError( - "Use DistributionStrategy.update() to modify a MirroredVariable.") - return device - - class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): """Class for defining how to restore a MirroredVariable.""" @@ -366,13 +346,14 @@ class MirroredVariable(DistributedVariable, Mirrored, f = kwargs.pop("f") if distribution_strategy_context.get_cross_tower_context(): update_device = distribute_lib.get_update_device() - # We are calling update on the mirrored variable in cross tower context. if update_device is not None: - # We are calling an assign function on the mirrored variable in cross - # tower context. + # We are calling an assign function on the mirrored variable in an + # update context. v = self.get(device=update_device) return f(v, *args, **kwargs) + # We are calling assign on the mirrored variable in cross tower context, + # use update to update the variable. 
return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: @@ -1057,3 +1038,156 @@ def value_container(val): if container is not None: return container return val + + +# TODO(josh11b): Descend from Variable. +class AggregatingVariable(checkpointable.CheckpointableBase): + """A wrapper around a variable that aggregates updates across towers.""" + + def __init__(self, v, aggregation): + self._v = v + # TODO(josh11b): Set v._distributed_container? + # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access + self._aggregation = aggregation + + def get(self): + return self._v + + def __getattr__(self, name): + return getattr(self._v, name) + + def _assign_func(self, *args, **kwargs): + f = kwargs.pop("f") + if distribution_strategy_context.get_cross_tower_context(): + update_device = distribute_lib.get_update_device() + if update_device is not None: + # We are calling an assign function in an update context. + return f(self._v, *args, **kwargs) + + # We are calling an assign function in cross tower context, wrap it in an + # update call. + return distribution_strategy_context.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + assert distribution_strategy_context.get_tower_context() + # We are calling an assign function in tower context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function with the reduced value. 
+ if self._aggregation == vs.VariableAggregation.NONE: + raise ValueError("You must specify an aggregation method to update a " + "a variable in Tower Context.") + + def merge_fn(strategy, value, *other_args, **other_kwargs): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) + + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) + + def assign_sub(self, *args, **kwargs): + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) + + def assign_add(self, *args, **kwargs): + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._assign_func(f=assign_add_fn, *args, **kwargs) + + def assign(self, *args, **kwargs): + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._assign_func(f=assign_fn, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation + + @property + def name(self): + return self._v.name + + @property + def dtype(self): + return self._v.dtype + + # TODO(josh11b): Test saving & restoring. 
+ def _gather_saveables_for_checkpoint(self): + return {checkpointable.VARIABLE_VALUE_KEY: self._v} + + # pylint: disable=multiple-statements + def __add__(self, o): return self._v + o + def __radd__(self, o): return o + self._v + def __sub__(self, o): return self._v - o + def __rsub__(self, o): return o - self._v + def __mul__(self, o): return self._v * o + def __rmul__(self, o): return o * self._v + def __truediv__(self, o): return self._v / o + def __rtruediv__(self, o): return o / self._v + def __floordiv__(self, o): return self._v // o + def __rfloordiv__(self, o): return o // self._v + def __mod__(self, o): return self._v % o + def __rmod__(self, o): return o % self._v + def __lt__(self, o): return self._v < o + def __le__(self, o): return self._v <= o + def __gt__(self, o): return self._v > o + def __ge__(self, o): return self._v >= o + def __and__(self, o): return self._v & o + def __rand__(self, o): return o & self._v + def __or__(self, o): return self._v | o + def __ror__(self, o): return o | self._v + def __xor__(self, o): return self._v ^ o + def __rxor__(self, o): return o ^ self._v + def __getitem__(self, o): return self._v[o] + def __pow__(self, o, modulo=None): return pow(self._v, o, modulo) + def __rpow__(self, o): return pow(o, self._v) + def __invert__(self): return ~self._v + def __neg__(self): return -self._v + def __abs__(self): return abs(self._v) + + def __div__(self, o): + try: + return self._v.__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self._v.__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self._v.__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return 
self._v.__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __str__(self): + return str(self._v) + + def __repr__(self): + return repr(self._v) + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function( + AggregatingVariable, _tensor_conversion_aggregate) +ops.register_dense_tensor_like_type(AggregatingVariable) |