aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-24 09:12:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 09:18:02 -0700
commit6f00ab7b8f16f450c00375df271c45da4dc72be5 (patch)
tree25163f5e9d441de60df20f1e990758271cbcf316 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent247b81a7c47fe52a383c86a9a32efa536ead6fa6 (diff)
For ParameterServerStrategy, make sure to include the AggregatingVariable
wrapper for variables in collections instead of what it wraps. PiperOrigin-RevId: 210107528
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index ecaf60f350..e87b48ba41 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -276,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -289,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result