aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-12 07:02:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 07:05:39 -0700
commit8e7ae1c8c78cebc7cc98cb99b3f8a3e8a415b5ff (patch)
treeae288a1a1e379e15fa7ef2057af49631e709140f /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentcba0c951587bbf93144e4821013dbf5ae6cb5efe (diff)
Automated g4 rollback of changelist 197218170
PiperOrigin-RevId: 200209039
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index cef0a2907b..403e47d94f 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -31,6 +31,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
@@ -343,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**values.select_device_mirrored(d, kwargs))
return values.regroup(updates, values.Mirrored)
+ def read_var(self, tower_local_var):
+ """Read the aggregate value of a tower-local variable."""
+ if isinstance(tower_local_var, values.TowerLocalVariable):
+ return math_ops.add_n(self.unwrap(tower_local_var))
+ assert isinstance(tower_local_var, values.Mirrored)
+ return tower_local_var.get()
+
def _fetch(self, val, destination, fn):
"""Return a copy of `val` or `fn(val)` on `destination`."""
if isinstance(val, values.TowerLocalVariable):