about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
author: Anjali Sridhar <anjalisridhar@google.com> 2018-06-05 10:36:12 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-05 10:39:04 -0700
commit: a7c026e08864417b35dbe3c9e4b246725ad6ba59 (patch)
tree: 436670cfa69e4a79d9c0791de69d7a727daf40f7 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent: c8090fa6acac1f9724671407964662137911921f (diff)
Respect name scopes opened in tower mode when creating vars in cross tower mode.
PiperOrigin-RevId: 199319758
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy.py 35
1 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 6eadba976b..cef0a2907b 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -118,7 +118,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if i > 0:
# Give replicas meaningful distinct names:
var0name = index[devices[0]].name.split(":")[0]
- kwargs["name"] = "%s/replica_%d" % (var0name, i)
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
if context.executing_eagerly():
kwargs["initial_value"] = array_ops.identity(
@@ -258,8 +261,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
{t.device: t.merge_args for t in threads})
merge_kwargs = values.regroup(
{t.device: t.merge_kwargs for t in threads})
- merge_result = threads[0].merge_fn(
- self, *merge_args, **merge_kwargs)
+ # We capture the name_scope of the MTT when we call merge_fn
+ # to ensure that if we have opened a name scope in the MTT,
+ # it will be respected when executing the merge function. We only
+ # capture the name_scope from the first MTT and assume it is
+ # the same for all other MTTs.
+ mtt_captured_name_scope = threads[0].captured_name_scope
+ with ops.name_scope(mtt_captured_name_scope):
+ merge_result = threads[0].merge_fn(
+ self, *merge_args, **merge_kwargs)
for t in threads:
t.merge_result = values.select_device(t.device, merge_result)
finally:
@@ -428,6 +438,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self.merge_args = None
self.merge_kwargs = None
self.merge_result = None
+ self.captured_name_scope = None
# We use a thread.Event for the main thread to signal when this
# thread should start running (`should_run`), and another for
# this thread to transfer control back to the main thread
@@ -451,13 +462,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._variable_creator_stack = self.graph._variable_creator_stack[:]
self._captured_var_scope = variable_scope.get_variable_scope()
# Adding a "/" at end lets us re-enter this scope later.
- self._captured_name_scope = self.graph.get_name_scope()
- if self._captured_name_scope:
- self._captured_name_scope += "/"
+ self._name_scope = self.graph.get_name_scope()
+ if self._name_scope:
+ self._name_scope += "/"
if self.tower_id > 0:
- if not self._captured_name_scope:
- self._captured_name_scope = ""
- self._captured_name_scope += "tower_%d/" % self.tower_id
+ if not self._name_scope:
+ self._name_scope = ""
+ self._name_scope += "tower_%d/" % self.tower_id
def run(self):
# pylint: disable=protected-access
@@ -473,7 +484,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
_enter_graph(self.graph), \
MirroredTowerContext(self.distribution, self.tower_id), \
ops.device(self.device), \
- ops.name_scope(self._captured_name_scope), \
+ ops.name_scope(self._name_scope), \
variable_scope.variable_scope(
self._captured_var_scope, reuse=self.tower_id > 0), \
variable_scope.variable_creator_scope(self.variable_creator_fn):
@@ -499,6 +510,10 @@ class MirroredTowerContext(distribute_lib.TowerContext):
t.merge_fn = fn
t.merge_args = args
t.merge_kwargs = kwargs
+ t.captured_name_scope = t.graph.get_name_scope()
+ # Adding a "/" at end lets us re-enter this scope later.
+ if t.captured_name_scope:
+ t.captured_name_scope += "/"
t.has_paused.set()
t.should_run.wait()
t.should_run.clear()