aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-07-22 16:32:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-22 16:35:59 -0700
commit89e06304aad35bfb019a8c10f39fc1ead83e0f99 (patch)
tree6a26fb960a3a1870938e5f16d8316f51f2cb2d0c /tensorflow/contrib/distribute/python/values.py
parent012f97121441f936b5262b98e2ca488c0c92422f (diff)
Add support for `is_tensor_like` property to DistributedValues and add support for calling `assign` on TowerLocalVariables.
PiperOrigin-RevId: 205595323
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 3162aebf5b..47dcf679c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -77,6 +78,13 @@ class DistributedValues(object):
def devices(self):
return list(self._index.keys())
+ @property
+ def is_tensor_like(self):
+ for v in self._index.values():
+ if not tensor_util.is_tensor(v):
+ return False
+ return True
+
def __str__(self):
return "%s:%s" % (self.__class__.__name__, self._index)
@@ -352,6 +360,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
return distribute_lib.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
+ _assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
# context.
# We reduce the value we want to assign/add/sub. More details about how we
@@ -448,14 +457,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
- # To preserve the sum across save and restore, we have to divide the
- # total across all devices when restoring a variable that was summed
- # when saving.
- if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM:
- tensor *= 1. / len(self._tower_local_variable.devices)
- return control_flow_ops.group([
- _assign_on_device(d, v, tensor)
- for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access
+ return self._tower_local_variable.assign(tensor)
def _assert_tower_context():
@@ -482,8 +484,19 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- _assert_tower_context()
- return self.get().assign(*args, **kwargs)
+ if distribute_lib.get_cross_tower_context():
+ # To preserve the sum across save and restore, we have to divide the
+ # total across all devices when restoring a variable that was summed
+ # when saving.
+ tensor = args[0]
+ if self._aggregation == vs.VariableAggregation.SUM:
+ tensor *= 1. / len(self.devices)
+ return control_flow_ops.group(
+ [_assign_on_device(d, v, tensor)
+ for d, v in six.iteritems(self._index)])
+ else:
+ _assert_tower_context()
+ return self.get().assign(*args, **kwargs)
@property
def aggregation(self):