diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index a32424b316..0f82508428 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) for v in index.values(): - l.remove(v) + if v in l: + l.remove(v) g.add_to_collections(collections, result) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) @@ -461,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) - else: - def initial_value_fn(device=d): + def initial_value_fn(device=d): + if context.executing_eagerly(): + init_value = index[devices[0]].value() + return array_ops.identity(init_value) + else: with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) - kwargs["initial_value"] = initial_value_fn + init_value = index[devices[0]].initial_value + return array_ops.identity(init_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) + # Don't record operations (e.g. other variable reads) during + # variable creation. + with tape.stop_recording(): + v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v return index |