diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 77079d0df9..9809204f8f 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" index = {} + unique_var_name = ops.get_default_graph().unique_name( + kwargs["name"], mark_as_used=False).rstrip("/") collective_instance_key = self._collective_keys.get_instance_key( - key_id=kwargs["name"]) + key_id=unique_var_name) if "initial_value" not in kwargs: raise ValueError("Initial value must be specified.") initial_value = kwargs["initial_value"] @@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) + if i == 0: + actual_var_name = v.name.split(":")[0] + assert unique_var_name == actual_var_name, "%r vs %r" % ( + unique_var_name, actual_var_name) assert not isinstance(v, values.DistributedVariable) index[d] = v return index @@ -210,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): """Configures the object. Args: - session_config: a @{tf.ConfigProto} + session_config: a `tf.ConfigProto` cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the cluster configurations. task_type: the current task type, such as "worker". @@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if not session_config or not self._cluster_spec: return - session_config.isolate_session_state = True - assert self._task_type assert self._task_id is not None |