aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 89f2c431fe..14dbbd6e27 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
import threading
import six
@@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib
# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+@contextlib.contextmanager
+def _enter_graph(g):
+ if context.executing_eagerly():
+ with g.as_default(), context.eager_mode():
+ yield
+ else:
+ with g.as_default():
+ yield
+
+
def _cpu_device(device):
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
@@ -458,7 +469,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with self.coord.stop_on_exception(), \
context.context()._mode(self.context_mode), \
context.context().device_policy(self.context_device_policy), \
- self.graph.as_default(), \
+ _enter_graph(self.graph), \
MirroredTowerContext(self.distribution, self.tower_id), \
ops.device(self.device), \
ops.name_scope(self._captured_name_scope), \