aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-29 10:17:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 10:22:42 -0700
commitaca93368a979419360c1fd84b53b1766b19ba81a (patch)
tree2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/contrib/distribute/python/values.py
parent8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff)
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global
step counter. This allows us to get rid of the increment_var() function and just use a standard assign_add(). PiperOrigin-RevId: 210743165
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 3ccaa2690e..479b7f39d6 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -523,6 +523,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self._aggregation
def _get_cross_tower(self):
+ if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self._primary_var
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)