aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-20 20:49:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 20:53:29 -0700
commit7de22654844a41575a30cd1ce3a522abd0516fde (patch)
tree8e5cf95bb8404ed9f05bc510c9f3f25dda9653f5 /tensorflow/contrib/distribute/python/values.py
parent3f78692cd7b47e9276e8809dd891578758f2de13 (diff)
Correctly use the aggregation mode set for variables in
ParameterServerStrategy when using >1 device per machine. This means wrapping the variable instances returned in that case in a class that intercepts assign_*() method calls. PiperOrigin-RevId: 209533673
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py180
1 files changed, 157 insertions, 23 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8548a86421..a58bb3a849 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -308,26 +308,6 @@ class DistributedVariable(DistributedDelegate):
ops.register_dense_tensor_like_type(DistributedVariable)
-def _get_update_device():
- """Validate we are in update/update_non_slot() and return current device.
-
- This is used in MirroredVariable.assign* members, to make sure they
- are only called via an update method, to make sure all components of the
- variable are being updated in a consistent way.
-
- Returns:
- A string device.
-
- Raises:
- RuntimeError: If not in distribution.update()/.update_non_slot().
- """
- device = distribute_lib.get_update_device()
- if device is None:
- raise RuntimeError(
- "Use DistributionStrategy.update() to modify a MirroredVariable.")
- return device
-
-
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
"""Class for defining how to restore a MirroredVariable."""
@@ -366,13 +346,14 @@ class MirroredVariable(DistributedVariable, Mirrored,
f = kwargs.pop("f")
if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
- # We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
- # We are calling an assign function on the mirrored variable in cross
- # tower context.
+ # We are calling an assign function on the mirrored variable in an
+ # update context.
v = self.get(device=update_device)
return f(v, *args, **kwargs)
+ # We are calling assign on the mirrored variable in cross tower context,
+ # use update to update the variable.
return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
@@ -1057,3 +1038,156 @@ def value_container(val):
if container is not None:
return container
return val
+
+
+# TODO(josh11b): Descend from Variable.
+class AggregatingVariable(checkpointable.CheckpointableBase):
+ """A wrapper around a variable that aggregates updates across towers."""
+
+ def __init__(self, v, aggregation):
+ self._v = v
+ # TODO(josh11b): Set v._distributed_container?
+ # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+ self._aggregation = aggregation
+
+ def get(self):
+ return self._v
+
+ def __getattr__(self, name):
+ return getattr(self._v, name)
+
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ if update_device is not None:
+ # We are calling an assign function in an update context.
+ return f(self._v, *args, **kwargs)
+
+ # We are calling an assign function in cross tower context, wrap it in an
+ # update call.
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ assert distribution_strategy_context.get_tower_context()
+ # We are calling an assign function in tower context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function with the reduced value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "a variable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ def assign_sub(self, *args, **kwargs):
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def name(self):
+ return self._v.name
+
+ @property
+ def dtype(self):
+ return self._v.dtype
+
+ # TODO(josh11b): Test saving & restoring.
+ def _gather_saveables_for_checkpoint(self):
+ return {checkpointable.VARIABLE_VALUE_KEY: self._v}
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self._v + o
+ def __radd__(self, o): return o + self._v
+ def __sub__(self, o): return self._v - o
+ def __rsub__(self, o): return o - self._v
+ def __mul__(self, o): return self._v * o
+ def __rmul__(self, o): return o * self._v
+ def __truediv__(self, o): return self._v / o
+ def __rtruediv__(self, o): return o / self._v
+ def __floordiv__(self, o): return self._v // o
+ def __rfloordiv__(self, o): return o // self._v
+ def __mod__(self, o): return self._v % o
+ def __rmod__(self, o): return o % self._v
+ def __lt__(self, o): return self._v < o
+ def __le__(self, o): return self._v <= o
+ def __gt__(self, o): return self._v > o
+ def __ge__(self, o): return self._v >= o
+ def __and__(self, o): return self._v & o
+ def __rand__(self, o): return o & self._v
+ def __or__(self, o): return self._v | o
+ def __ror__(self, o): return o | self._v
+ def __xor__(self, o): return self._v ^ o
+ def __rxor__(self, o): return o ^ self._v
+ def __getitem__(self, o): return self._v[o]
+ def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
+ def __rpow__(self, o): return pow(o, self._v)
+ def __invert__(self): return ~self._v
+ def __neg__(self): return -self._v
+ def __abs__(self): return abs(self._v)
+
+ def __div__(self, o):
+ try:
+ return self._v.__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self._v.__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self._v.__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self._v.__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __str__(self):
+ return str(self._v)
+
+ def __repr__(self):
+ return repr(self._v)
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(
+ AggregatingVariable, _tensor_conversion_aggregate)
+ops.register_dense_tensor_like_type(AggregatingVariable)