diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-07 13:24:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 17:05:34 -0700 |
commit | bd8d7440d7121dc1e92c4794ca1d18d0e9eb0a17 (patch) | |
tree | bced7f226f01f8ff4fab8258373d47e504dc8647 /tensorflow/contrib/distribute/python/values.py | |
parent | 914c971c7b690661754e83549325c5deadd9e62d (diff) |
Fixes for accessing variables with a MirroredStrategy in a
cross-tower context:
* only provide read-only access to variables via get()
* don't fail if the variable isn't copied to the current device in
get()
* make _as_graph_element() return the aggregate value for tower-local
variables (instead of the incorrect previous behavior of returning
the primary)
PiperOrigin-RevId: 195711474
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 44 |
1 files changed, 37 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index b04734f1a3..759f3c3599 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.training import checkpointable from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib @@ -60,7 +61,7 @@ class DistributedValues(object): else: device = distribute_lib.get_update_device() if device is None: - device = device_util.current() + return self._get_cross_tower() device = device_util.canonicalize(device) try: return self._index[device] @@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op - def _as_graph_element(self): - # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): - return self._primary_var._as_graph_element() - return self.get()._as_graph_element() - def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored, def assign(self, *args, **kwargs): return self.get(device=_get_update_device()).assign(*args, **kwargs) + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def 
_gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice, super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): + _assert_tower_context() return self.get().assign(*args, **kwargs) @property def reduce_method(self): return self._reduce_method + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. + total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. |