aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py23
1 files changed, 14 insertions, 9 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index a32424b316..0f82508428 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
- l.remove(v)
+ if v in l:
+ l.remove(v)
g.add_to_collections(collections, result)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
@@ -461,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
- if context.executing_eagerly():
- kwargs["initial_value"] = array_ops.identity(
- index[devices[0]].value())
- else:
- def initial_value_fn(device=d):
+ def initial_value_fn(device=d):
+ if context.executing_eagerly():
+ init_value = index[devices[0]].value()
+ return array_ops.identity(init_value)
+ else:
with ops.device(device):
- return array_ops.identity(index[devices[0]].initial_value)
- kwargs["initial_value"] = initial_value_fn
+ init_value = index[devices[0]].initial_value
+ return array_ops.identity(init_value)
+ kwargs["initial_value"] = initial_value_fn
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- v = next_creator(*args, **kwargs)
+ # Don't record operations (e.g. other variable reads) during
+ # variable creation.
+ with tape.stop_recording():
+ v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index