aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-06-26 11:25:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 11:28:45 -0700
commitbfda539bef38845809e3b0c5930458dc500d505d (patch)
treeca88e6879357c02b08263cf81e197a8dc47efae1 /tensorflow/contrib/distribute/python/values.py
parentd10213099df42d7138dd7479264e4c987a3d870f (diff)
Enable assign, assign_add and assign_sub to be called on Mirrored Variables in cross tower and tower context.
PiperOrigin-RevId: 202162272
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py89
1 files changed, 63 insertions, 26 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 9a48928a95..ce95b718f6 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import weakref
-
import six
from tensorflow.contrib.distribute.python import input_ops
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
@@ -251,21 +251,6 @@ class DistributedVariable(DistributedDelegate):
ops.register_dense_tensor_like_type(DistributedVariable)
-class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
- """Class for defining how to restore a MirroredVariable."""
-
- def __init__(self, mirrored_variable, primary_variable, name):
- self._mirrored_variable = mirrored_variable
- super(_MirroredSaveable, self).__init__(primary_variable, "", name)
-
- def restore(self, restored_tensors, restored_shapes):
- """Restore the same value into all variables."""
- tensor, = restored_tensors
- return control_flow_ops.group([
- _assign_on_device(d, v, tensor)
- for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access
-
-
def _get_update_device():
"""Validate we are in update/update_non_slot() and return current device.
@@ -286,30 +271,82 @@ def _get_update_device():
return device
+class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
+ """Saveable defining how a MirroredVariable is restored on every device."""
+
+ def __init__(self, mirrored_variable, primary_variable, name):
+ self._mirrored_variable = mirrored_variable
+ super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Assign the single restored tensor to each variable in the index."""
+ tensor, = restored_tensors # unpacking requires exactly one tensor
+ return control_flow_ops.group([
+ _assign_on_device(d, v, tensor)
+ for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access
+
+
class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are kept in sync."""
- def __init__(self, index, primary_var):
+ def __init__(self, index, primary_var, aggregation_method=None):
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
+ self._aggregation_method = aggregation_method
super(MirroredVariable, self).__init__(index)
- # We use _get_update_device() for the assign* methods to enforce
- # that we are in an update() function. The arguments to update() are
- # automatically unwrapped so the update() function would normally
- # see regular variables, not MirroredVariables. However, the update
- # function can still operate on wrapped MirroredVariables through
- # object members, captured arguments, etc. This is more likely in an
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
# update_non_slot() function (like OptimizerV2._finish), which can
# update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribute_lib.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ # Cross tower context: check whether an update device is currently set.
+ if update_device is not None:
+ # An update device is set (presumably we are already inside an update()
+ # call), so apply f to the variable self.get() returns for that device.
+ v = self.get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribute_lib.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ # Tower context: hand merge_fn to merge_call(). merge_fn reduces the value
+ # to assign/add/sub with self._aggregation_method and passes the result,
+ # together with f, to strategy.update(); see the _reduce method for how the
+ # different use cases are handled. An aggregation method is required.
+ # NOTE(review): merge_fn takes only `value`; confirm that extra assign
+ # args/kwargs passed to merge_call() are never needed here.
+ if not self._aggregation_method:
+ raise ValueError("You must specify an aggregation method to update a "
+ "MirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value):
+ return strategy.update(self,
+ f,
+ strategy.reduce(
+ method_string=self._aggregation_method,
+ value=value,
+ destinations=self))
+ return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
+ **kwargs)
+
def assign_sub(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
def assign_add(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign_add(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
def assign(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign, *args, **kwargs)
def _get_cross_tower(self):
device = device_util.canonicalize(device_util.current())