diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 89f2c431fe..14dbbd6e27 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import threading import six @@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib # TODO(josh11b): Replace asserts in this file with if ...: raise ... +@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) @@ -458,7 +469,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with self.coord.stop_on_exception(), \ context.context()._mode(self.context_mode), \ context.context().device_policy(self.context_device_policy), \ - self.graph.as_default(), \ + _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ ops.name_scope(self._captured_name_scope), \ |