aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-03-07 12:03:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-07 12:10:42 -0800
commit37cef895bfe06913477b87917cbee7284aefa7cd (patch)
tree4f05a013578c0459a52fc5e6448bb3dfc2d04971 /tensorflow/python/framework
parent808b569e85df8d63590740f05bc14d964efc4801 (diff)
eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation to introduce one public, stable symbol: tf.executing_eagerly() (i.e., part of moving APIs related to eager execution from "contrib" to a namespace where we provide API stability guarantees) PiperOrigin-RevId: 188212646
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/constant_op.py2
-rw-r--r--tensorflow/python/framework/function.py6
-rw-r--r--tensorflow/python/framework/meta_graph.py4
-rw-r--r--tensorflow/python/framework/ops.py40
-rw-r--r--tensorflow/python/framework/ops_test.py25
-rw-r--r--tensorflow/python/framework/random_seed.py20
-rw-r--r--tensorflow/python/framework/random_seed_test.py8
-rw-r--r--tensorflow/python/framework/tensor_util.py2
-rw-r--r--tensorflow/python/framework/test_util.py2
9 files changed, 55 insertions, 54 deletions
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index d3d8c9c154..782b505d6c 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -181,7 +181,7 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
TypeError: if shape is incorrectly specified or unsupported.
"""
ctx = context.context()
- if not ctx.in_graph_mode():
+ if ctx.executing_eagerly():
t = convert_to_eager_tensor(value, ctx, dtype)
if shape is None:
return t
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index caa604999c..14d72d8a3d 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -489,10 +489,10 @@ class _DefinedFunction(object):
# Adds this function into 'g'.
# pylint: disable=protected-access
- if context.in_graph_mode():
- g._add_function(self)
- else:
+ if context.executing_eagerly():
context.context().add_function_def(self.definition)
+ else:
+ g._add_function(self)
# pylint: enable=protected-access
# Ensures related sub-routines are defined in 'g', too.
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
index 4c1bd736d7..4bb9941bb7 100644
--- a/tensorflow/python/framework/meta_graph.py
+++ b/tensorflow/python/framework/meta_graph.py
@@ -695,7 +695,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
Raises:
ValueError: If the graph_def contains unbound inputs.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled.")
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
@@ -856,7 +856,7 @@ def export_scoped_meta_graph(filename=None,
Raises:
ValueError: When the `GraphDef` is larger than 2GB.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"Eager Execution is enabled.")
graph = graph or ops.get_default_graph()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2a8319a19f..8ff247fdb1 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -395,10 +395,10 @@ class Tensor(_TensorLike):
"Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
def __iter__(self):
- if context.in_graph_mode():
+ if not context.executing_eagerly():
raise TypeError(
- "`Tensor` objects are not iterable when eager execution is not "
- "enabled. To iterate over this tensor use `tf.map_fn`.")
+ "Tensor objects are not iterable when eager execution is not "
+ "enabled. To iterate over this tensor use tf.map_fn.")
shape = self._shape_tuple()
if shape is None:
raise TypeError("Cannot iterate over a tensor with unknown shape.")
@@ -772,7 +772,7 @@ class _EagerTensorBase(Tensor):
six.raise_from(core._status_to_exception(e.code, e.message), None)
# Record the copy on tape and define backprop copy as well.
- if not context.in_graph_mode():
+ if context.executing_eagerly():
self_device = self.device
def grad_fun(dresult):
return [dresult._copy(device_name=self_device)]
@@ -993,7 +993,7 @@ def internal_convert_to_tensor(value,
"""
if ctx is None: ctx = context.context()
- if ctx.in_eager_mode():
+ if ctx.executing_eagerly():
# Fast path for EagerTensors that don't need any conversion.
if isinstance(value, EagerTensor):
# Note that we don't check that value's dtype matches the dtype
@@ -4797,15 +4797,15 @@ def device(device_name_or_function):
Raises:
RuntimeError: If eager execution is enabled and a function is passed in.
"""
- if context.in_graph_mode():
- return get_default_graph().device(device_name_or_function)
- else:
+ if context.executing_eagerly():
# TODO(agarwal): support device functions in EAGER mode.
if callable(device_name_or_function):
raise RuntimeError(
"tf.device does not support functions when eager execution "
"is enabled.")
return context.device(device_name_or_function)
+ else:
+ return get_default_graph().device(device_name_or_function)
@tf_export("container")
@@ -4824,7 +4824,12 @@ def container(container_name):
@tf_export("colocate_with")
def colocate_with(op, ignore_existing=False):
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ if op is not None:
+ return device(op.device)
+ else:
+ return _NullContextmanager()
+ else:
default_graph = get_default_graph()
if isinstance(op, EagerTensor):
if default_graph.building_function:
@@ -4833,11 +4838,6 @@ def colocate_with(op, ignore_existing=False):
raise ValueError("Encountered an Eager-defined Tensor during graph "
"construction, but a function was not being built.")
return default_graph.colocate_with(op, ignore_existing)
- else:
- if op is not None:
- return device(op.device)
- else:
- return _NullContextmanager()
@tf_export("control_dependencies")
@@ -4857,10 +4857,10 @@ def control_dependencies(control_inputs):
A context manager that specifies control dependencies for all
operations constructed within the context.
"""
- if context.in_graph_mode():
- return get_default_graph().control_dependencies(control_inputs)
- else:
+ if context.executing_eagerly():
return _NullContextmanager()
+ else:
+ return get_default_graph().control_dependencies(control_inputs)
class _DefaultStack(threading.local):
@@ -5123,7 +5123,7 @@ def init_scope():
"""
# pylint: enable=g-doc-return-or-yield,line-too-long
- if context.in_eager_mode():
+ if context.executing_eagerly():
# Fastpath.
with tape.stop_recording():
yield
@@ -5705,7 +5705,7 @@ class name_scope(object): # pylint: disable=invalid-name
self._default_name = default_name
self._values = values
self._ctx = context.context()
- self._in_eager_mode = self._ctx.in_eager_mode()
+ self._in_eager_mode = self._ctx.executing_eagerly()
def __enter__(self):
"""Start the scope block.
@@ -5884,7 +5884,7 @@ def get_from_proto_function(collection_name):
def _assert_collection_is_ok(collection_name):
- if context.in_eager_mode():
+ if context.executing_eagerly():
if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access
raise ValueError("When Eager Execution is enabled, variable "
"collections are not supported.")
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 55576f0e88..c294f830bc 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1763,7 +1763,13 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
return constant_op.constant(2.0)
future.calls = 0
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ a = constant_op.constant(1.0)
+ b = future()
+ with ops.control_dependencies([a, b]):
+ c = constant_op.constant(3.0)
+ self.assertEqual(future.calls, 1)
+ else:
g = ops.Graph()
with g.as_default():
a = constant_op.constant(1.0)
@@ -1772,12 +1778,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
c = constant_op.constant(3.0)
self.assertEqual(c.op.control_inputs, [a.op, b.op])
self.assertEqual(future.calls, 1)
- else:
- a = constant_op.constant(1.0)
- b = future()
- with ops.control_dependencies([a, b]):
- c = constant_op.constant(3.0)
- self.assertEqual(future.calls, 1)
def testBasicWithConversion(self):
g = ops.Graph()
@@ -2150,11 +2150,11 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with ops.init_scope():
# Because g is building a function, init_scope should
# escape out to the eager context.
- self.assertTrue(context.in_eager_mode())
+ self.assertTrue(context.executing_eagerly())
# g should be reinstated as the default graph, and the
# graph context should be re-entered.
self.assertIs(g, ops.get_default_graph())
- self.assertTrue(context.in_graph_mode())
+ self.assertFalse(context.executing_eagerly())
def testStaysInEagerWhenOnlyEagerContextActive(self):
with context.eager_mode():
@@ -2277,12 +2277,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with context.eager_mode():
def foo():
with ops.name_scope("inner"), ops.init_scope():
- if context.in_graph_mode():
- self.assertEqual(ops.get_name_scope(), "inner")
- else:
+ if context.executing_eagerly():
# A trailing slash is always appended when eager execution is
# enabled.
self.assertEqual(context.context().scope_name, "inner/")
+ else:
+ self.assertEqual(ops.get_name_scope(), "inner")
+
foo()
self.assertEqual(ops.get_name_scope(), "")
foo_compiled = eager_function.defun(foo)
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 1e74a790a3..b724432e00 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -52,20 +52,20 @@ def get_seed(op_seed):
A tuple of two integers that should be used for the local seed of this
operation.
"""
- is_graph_mode = context.in_graph_mode()
+ eager = context.executing_eagerly()
- if is_graph_mode:
- global_seed = ops.get_default_graph().seed
- else:
+ if eager:
global_seed = context.global_seed()
+ else:
+ global_seed = ops.get_default_graph().seed
if global_seed is not None:
if op_seed is None:
# pylint: disable=protected-access
- if is_graph_mode:
- op_seed = ops.get_default_graph()._last_id
- else:
+ if eager:
op_seed = context.internal_operation_seed()
+ else:
+ op_seed = ops.get_default_graph()._last_id
seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
else:
@@ -176,7 +176,7 @@ def set_random_seed(seed):
Args:
seed: integer.
"""
- if context.in_graph_mode():
- ops.get_default_graph().seed = seed
- else:
+ if context.executing_eagerly():
context.set_global_seed(seed)
+ else:
+ ops.get_default_graph().seed = seed
diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py
index b4c98ab8b2..1944922686 100644
--- a/tensorflow/python/framework/random_seed_test.py
+++ b/tensorflow/python/framework/random_seed_test.py
@@ -40,13 +40,13 @@ class RandomSeedTest(test.TestCase):
((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either
((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument
]
- if context.in_graph_mode():
- # 0 will be the default_graph._lastid.
- test_cases.append(((1, None), (1, 0)))
- else:
+ if context.executing_eagerly():
# operation seed is random number generated based on global seed.
# it's not tested due to possibility of platform or version difference.
pass
+ else:
+ # 0 will be the default_graph._lastid.
+ test_cases.append(((1, None), (1, 0)))
for tc in test_cases:
tinput, toutput = tc[0], tc[1]
random_seed.set_random_seed(tinput[0])
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 135562e831..984bcecdfe 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -828,7 +828,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 1c8398e686..9fc1154201 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -816,7 +816,7 @@ class TensorFlowTestCase(googletest.TestCase):
Returns:
tensors numpy values.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return self._eval_helper(tensors)
else:
sess = ops.get_default_session()