aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-05-25 13:20:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-25 13:23:08 -0700
commit68430112b2ca5c160db6dd412d43f572ec69e72f (patch)
tree90b2c71f25bd56ddcc179bfd0855bf858f9a6384 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentb6ae98b4ac1ec3051d81f3133b827d6bb305aa2b (diff)
Public API to switch between eager execution and graph building.
Now, after tf.enable_eager_execution() has been executed, entering the context manager of a tf.Graph will enable graph mode. So, for example ``` tf.enable_eager_execution() with tf.Graph().as_default(): c = tf.constant(1.0) # this is a graph tensor c2 = tf.constant(1.0) # this is an eager tensor ``` The main use-case of this is allowing documentation writers to make a single notebook which starts with eager execution and seamlessly transitions to building graphs. This also makes many explicit enablings of graph mode in the code redundant (a cleanup cl will follow). PiperOrigin-RevId: 198092991
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 89f2c431fe..14dbbd6e27 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
import threading
import six
@@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib
# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+@contextlib.contextmanager
+def _enter_graph(g):
+ if context.executing_eagerly():
+ with g.as_default(), context.eager_mode():
+ yield
+ else:
+ with g.as_default():
+ yield
+
+
def _cpu_device(device):
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
@@ -458,7 +469,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with self.coord.stop_on_exception(), \
context.context()._mode(self.context_mode), \
context.context().device_policy(self.context_device_policy), \
- self.graph.as_default(), \
+ _enter_graph(self.graph), \
MirroredTowerContext(self.distribution, self.tower_id), \
ops.device(self.device), \
ops.name_scope(self._captured_name_scope), \