diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-07-22 16:32:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-22 16:35:59 -0700 |
commit | 89e06304aad35bfb019a8c10f39fc1ead83e0f99 (patch) | |
tree | 6a26fb960a3a1870938e5f16d8316f51f2cb2d0c /tensorflow/contrib/distribute/python/values.py | |
parent | 012f97121441f936b5262b98e2ca488c0c92422f (diff) |
Add support for `is_tensor_like` property to DistributedValues and add support for calling `assign` on TowerLocalVariables.
PiperOrigin-RevId: 205595323
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 33 |
1 file changed, 23 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 3162aebf5b..47dcf679c2 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -30,6 +30,7 @@ from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -77,6 +78,13 @@ class DistributedValues(object): def devices(self): return list(self._index.keys()) + @property + def is_tensor_like(self): + for v in self._index.values(): + if not tensor_util.is_tensor(v): + return False + return True + def __str__(self): return "%s:%s" % (self.__class__.__name__, self._index) @@ -352,6 +360,7 @@ class MirroredVariable(DistributedVariable, Mirrored, return distribute_lib.get_distribution_strategy().update( self, f, *args, **kwargs) else: + _assert_tower_context() # We are calling an assign function on the mirrored variable in tower # context. # We reduce the value we want to assign/add/sub. More details about how we @@ -448,14 +457,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, restored_shapes): """Restore the same value into all variables.""" tensor, = restored_tensors - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM: - tensor *= 1. / len(self._tower_local_variable.devices) - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + return self._tower_local_variable.assign(tensor) def _assert_tower_context(): @@ -482,8 +484,19 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - _assert_tower_context() - return self.get().assign(*args, **kwargs) + if distribute_lib.get_cross_tower_context(): + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + tensor = args[0] + if self._aggregation == vs.VariableAggregation.SUM: + tensor *= 1. / len(self.devices) + return control_flow_ops.group( + [_assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._index)]) + else: + _assert_tower_context() + return self.get().assign(*args, **kwargs) @property def aggregation(self): |