aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 13:24:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 17:05:34 -0700
commitbd8d7440d7121dc1e92c4794ca1d18d0e9eb0a17 (patch)
treebced7f226f01f8ff4fab8258373d47e504dc8647 /tensorflow/contrib/distribute/python/values.py
parent914c971c7b690661754e83549325c5deadd9e62d (diff)
Fixes for accessing variables with a MirroredStrategy in a
cross-tower context:
* only provide read-only access to variables via get()
* don't fail if the variable isn't copied to the current device in get()
* make _as_graph_element() return the aggregate value for tower-local
  variables (instead of the incorrect previous behavior of returning the
  primary)

PiperOrigin-RevId: 195711474
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py44
1 files changed, 37 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index b04734f1a3..759f3c3599 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.training import checkpointable
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -60,7 +61,7 @@ class DistributedValues(object):
else:
device = distribute_lib.get_update_device()
if device is None:
- device = device_util.current()
+ return self._get_cross_tower()
device = device_util.canonicalize(device)
try:
return self._index[device]
@@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
- def _as_graph_element(self):
- # pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
- return self._primary_var._as_graph_element()
- return self.get()._as_graph_element()
-
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
@@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self.get(device=_get_update_device()).assign(*args, **kwargs)
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return array_ops.identity(self._index[device])
+ return array_ops.identity(self._primary_var)
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
@@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access
+def _assert_tower_context():
+ if not distribute_lib.get_tower_context():
+ raise RuntimeError(
+ "Tower-local variables may only be assigned in a tower context.")
+
+
class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
@@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign_sub(*args, **kwargs)
def assign_add(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign(*args, **kwargs)
@property
def reduce_method(self):
return self._reduce_method
+ def _get_cross_tower(self):
+ all_components = tuple(self._index.values())
+ # TODO(josh11b): Use a strategy-specific method.
+ total = math_ops.add_n(all_components)
+ if self._reduce_method == "mean":
+ return total * (1./ len(all_components))
+ return total
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._get_cross_tower()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.