aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 77079d0df9..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
index = {}
+ unique_var_name = ops.get_default_graph().unique_name(
+ kwargs["name"], mark_as_used=False).rstrip("/")
collective_instance_key = self._collective_keys.get_instance_key(
- key_id=kwargs["name"])
+ key_id=unique_var_name)
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
@@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
+ if i == 0:
+ actual_var_name = v.name.split(":")[0]
+ assert unique_var_name == actual_var_name, "%r vs %r" % (
+ unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
@@ -210,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
@@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
if not session_config or not self._cluster_spec:
return
- session_config.isolate_session_state = True
-
assert self._task_type
assert self._task_id is not None