path: root/tensorflow/python/ops
author     Asim Shankar <ashankar@google.com>              2018-03-07 12:03:56 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-03-07 12:10:42 -0800
commit     37cef895bfe06913477b87917cbee7284aefa7cd (patch)
tree       4f05a013578c0459a52fc5e6448bb3dfc2d04971 /tensorflow/python/ops
parent     808b569e85df8d63590740f05bc14d964efc4801 (diff)
eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation to introduce one public, stable symbol: tf.executing_eagerly() (i.e., part of moving APIs related to eager execution from "contrib" to a namespace where we provide API stability guarantees).

PiperOrigin-RevId: 188212646
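For readers skimming the hunks below, here is a minimal sketch of the before/after pattern this commit applies across tensorflow/python/ops. It assumes a TensorFlow build at or after this revision, where tensorflow.python.eager.context exposes executing_eagerly(); the public tf.executing_eagerly() symbol mentioned above is only planned at this point. The helper names restore_static_shape and mode_name are hypothetical, chosen for illustration.

from tensorflow.python.eager import context


def restore_static_shape(result, original):
  """Set static shape info only while building a graph (no-op under eager).

  Mirrors the recurring pattern in this diff (e.g. _TileGrad, ones_like,
  dropout); `result` and `original` stand for any two tf.Tensor objects and
  are illustrative names, not symbols from the commit.
  """
  # Old: if context.in_graph_mode():           result.set_shape(...)
  # New: if not context.executing_eagerly():   result.set_shape(...)
  if not context.executing_eagerly():
    result.set_shape(original.get_shape())
  return result


def mode_name():
  # Old code needed two predicates (in_eager_mode / in_graph_mode);
  # after the rename a single predicate covers both cases.
  return "eager" if context.executing_eagerly() else "graph"

Collapsing the two checks onto a single predicate also removes the double-negative tests (not in_graph_mode()) that the old API encouraged, which is what most of the hunks below do mechanically.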
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/array_grad.py | 8
-rw-r--r--  tensorflow/python/ops/array_ops.py | 29
-rw-r--r--  tensorflow/python/ops/check_ops.py | 41
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 14
-rw-r--r--  tensorflow/python/ops/custom_gradient.py | 2
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py | 40
-rw-r--r--  tensorflow/python/ops/functional_ops.py | 8
-rw-r--r--  tensorflow/python/ops/gradients_impl.py | 9
-rw-r--r--  tensorflow/python/ops/io_ops.py | 2
-rw-r--r--  tensorflow/python/ops/lookup_ops.py | 8
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py | 2
-rw-r--r--  tensorflow/python/ops/math_grad.py | 8
-rw-r--r--  tensorflow/python/ops/math_ops.py | 10
-rw-r--r--  tensorflow/python/ops/math_ops_test.py | 4
-rw-r--r--  tensorflow/python/ops/metrics_impl.py | 60
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 2
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 9
-rw-r--r--  tensorflow/python/ops/numerics.py | 2
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 41
-rw-r--r--  tensorflow/python/ops/rnn.py | 10
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 15
-rw-r--r--  tensorflow/python/ops/script_ops.py | 2
-rw-r--r--  tensorflow/python/ops/state_ops.py | 2
-rw-r--r--  tensorflow/python/ops/template.py | 6
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py | 10
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 38
-rw-r--r--  tensorflow/python/ops/variables.py | 29
27 files changed, 207 insertions, 204 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 925cf8ef32..3c6a5c9e56 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -80,7 +80,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
def _ExtractInputShapes(inputs):
"""Extract the shapes of a set of input tensors."""
- if not context.in_graph_mode():
+ if context.executing_eagerly():
return array_ops.shape_n(inputs)
sizes = []
fully_known = True
@@ -106,7 +106,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
out_grads = []
if isinstance(grad, ops.Tensor):
- if context.in_eager_mode():
+ if context.executing_eagerly():
# Using mod here for convenience since concat_dim is already verified
# in concat implementation to be within the allowed [-rank, rank) range.
non_neg_concat_dim = (
@@ -428,7 +428,7 @@ def _GatherV2Grad(op, grad):
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
if axis_static == 0:
- if context.in_eager_mode():
+ if context.executing_eagerly():
params_tail_shape = params_shape.cpu()[1:]
else:
params_tail_shape = params_shape[1:]
@@ -578,7 +578,7 @@ def _TileGrad(op, grad):
axes = math_ops.range(0, array_ops.size(split_shape), 2)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference
- if context.in_graph_mode():
+ if not context.executing_eagerly():
input_grad.set_shape(op.inputs[0].get_shape())
return [input_grad, None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9108fe759b..b4e1b9d781 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -128,9 +128,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
Returns:
A `Tensor`. Has the same type as `input`.
"""
- if context.in_graph_mode():
- return gen_array_ops.identity(input, name=name)
- else:
+ if context.executing_eagerly():
input = ops.convert_to_tensor(input)
in_device = input.device
# TODO(ashankar): Does 'identity' need to invoke execution callbacks?
@@ -140,6 +138,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
if context_device != in_device:
return input._copy() # pylint: disable=protected-access
return input
+ else:
+ return gen_array_ops.identity(input, name=name)
# pylint: disable=redefined-builtin,protected-access
@@ -305,7 +305,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
sparse_tensor.SparseTensorValue)):
return gen_math_ops.cast(input.dense_shape, out_type)
else:
- if context.in_graph_mode():
+ if not context.executing_eagerly():
input_tensor = ops.convert_to_tensor(input)
input_shape = input_tensor.get_shape()
if optimize and input_shape.is_fully_defined():
@@ -330,7 +330,7 @@ def shape_n(input, out_type=dtypes.int32, name=None):
"""
output = gen_array_ops.shape_n(input, out_type=out_type, name=name)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
for i, input_tensor in enumerate(input):
input_tensor = ops.convert_to_tensor(input_tensor)
input_shape = input_tensor.get_shape()
@@ -385,9 +385,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
Returns:
A `Tensor` of type `out_type`. Defaults to `tf.int32`.
"""
- if context.in_eager_mode() and not isinstance(
- input, (sparse_tensor.SparseTensor,
- sparse_tensor.SparseTensorValue)):
+ if context.executing_eagerly() and not isinstance(
+ input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access
with ops.name_scope(name, "Size", [input]) as name:
if isinstance(input, (sparse_tensor.SparseTensor,
@@ -783,7 +782,7 @@ def strided_slice(input_,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# TODO(apassos) In eager mode assignment will be done by overriding
# __setitem__ instead.
op.assign = assign
@@ -1457,7 +1456,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
input_shape = ret.op.inputs[0].get_shape().dims
if input_shape is not None:
ret.set_shape(input_shape[::-1])
@@ -1622,7 +1621,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
with ops.name_scope(name, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
- if context.in_eager_mode():
+ if context.executing_eagerly():
if dtype is not None and dtype != tensor.dtype:
return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
@@ -1678,7 +1677,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
if dtype is None:
dtype = tensor.dtype
ret = ones(ones_shape, dtype=dtype, name=name)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
ret.set_shape(tensor.get_shape())
return ret
@@ -1759,7 +1758,7 @@ def placeholder(dtype, shape=None, name=None):
Raises:
RuntimeError: if eager execution is enabled
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.")
@@ -1822,7 +1821,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
Raises:
RuntimeError: if eager execution is enabled
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.")
@@ -1921,7 +1920,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl
raise ValueError("Unknown padding mode: %s" % mode)
# Restore shape information where possible.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
paddings_constant = tensor_util.constant_value(
result.op.inputs[1], partial=True)
input_shape = result.op.inputs[0].shape
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 0fd6e29a49..7d6e047d7c 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -169,7 +169,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@@ -210,7 +210,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@@ -251,7 +251,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@@ -293,7 +293,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@@ -343,7 +343,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
eq = math_ops.equal(x, y)
condition = math_ops.reduce_all(eq)
if not condition:
@@ -435,7 +435,7 @@ def assert_none_equal(
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -512,7 +512,7 @@ def assert_near(
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -562,7 +562,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -610,7 +610,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -658,7 +658,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_greater', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -708,7 +708,7 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None,
with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
+ if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@@ -808,7 +808,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
dynamic_condition = math_ops.equal
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = ''
else:
name = x.name
@@ -873,7 +873,7 @@ def assert_rank_at_least(
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
dynamic_condition = math_ops.greater_equal
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = ''
else:
name = x.name
@@ -1001,7 +1001,7 @@ def assert_rank_in(
ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
message = message or ''
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = ''
else:
name = x.name
@@ -1054,7 +1054,7 @@ def assert_integer(x, message=None, name=None):
with ops.name_scope(name, 'assert_integer', [x]):
x = ops.convert_to_tensor(x, name='x')
if not x.dtype.is_integer:
- if context.in_eager_mode():
+ if context.executing_eagerly():
name = 'tensor'
else:
name = x.name
@@ -1087,12 +1087,11 @@ def assert_type(tensor, tf_type, message=None, name=None):
with ops.name_scope(name, 'assert_type', [tensor]):
tensor = ops.convert_to_tensor(tensor, name='tensor')
if tensor.dtype != tf_type:
- if context.in_graph_mode():
- raise TypeError(
- '%s %s must be of type %s' % (message, tensor.name, tf_type))
+ if context.executing_eagerly():
+ raise TypeError('%s tensor must be of type %s' % (message, tf_type))
else:
- raise TypeError(
- '%s tensor must be of type %s' % (message, tf_type))
+ raise TypeError('%s %s must be of type %s' % (message, tensor.name,
+ tf_type))
return control_flow_ops.no_op('statically_determined_correct_type')
@@ -1240,7 +1239,7 @@ def assert_scalar(tensor, name=None):
tensor = ops.convert_to_tensor(tensor, name=name_scope)
shape = tensor.get_shape()
if shape.ndims != 0:
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise ValueError('Expected scalar shape, saw shape: %s.'
% (shape,))
else:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 4e524846cc..a2f52de749 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -152,7 +152,7 @@ def Assert(condition, data, summarize=None, name=None):
@compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
is not true
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
if not condition:
xs = ops.convert_n_to_tensor(data)
data_str = [_summarize_eager(x, summarize) for x in xs]
@@ -178,7 +178,7 @@ def Assert(condition, data, summarize=None, name=None):
condition, data, summarize, name="Assert")
guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
- if context.in_eager_mode():
+ if context.executing_eagerly():
return
return guarded_assert.op
@@ -2025,7 +2025,7 @@ def cond(pred,
raise TypeError("false_fn must be callable.")
with ops.name_scope(name, "cond", [pred]):
- if context.in_eager_mode():
+ if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())
@@ -3177,7 +3177,7 @@ def while_loop(cond,
math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
body = lambda i, lv: (i + 1, orig_body(*lv))
- if context.in_eager_mode():
+ if context.executing_eagerly():
while cond(*loop_vars):
loop_vars = body(*loop_vars)
if maximum_iterations is not None:
@@ -3271,7 +3271,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
Raises:
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return output_tensor
with ops.name_scope(name, "control_dependency",
list(dependencies) + [output_tensor]) as name:
@@ -3316,7 +3316,7 @@ def group(*inputs, **kwargs):
Raises:
ValueError: If an unknown keyword argument is provided.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return None
name = kwargs.pop("name", None)
if kwargs:
@@ -3396,7 +3396,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined
objects.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return tensors
with ops.name_scope(name, "tuple", tensors) as name:
tensors = [t if (isinstance(t, ops.Operation)
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index f199ba8fd4..9eacac1b37 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -92,7 +92,7 @@ def custom_gradient(f):
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if kwargs:
raise ValueError(
"The custom_gradient decorator currently suports keywords "
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 052caffd49..d2cc87555f 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -159,7 +159,7 @@ class QueueBase(object):
ValueError: If one of the arguments is invalid.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError(
"Queues are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.")
@@ -177,10 +177,10 @@ class QueueBase(object):
else:
self._names = None
self._queue_ref = queue_ref
- if context.in_graph_mode():
- self._name = self._queue_ref.op.name.split("/")[-1]
- else:
+ if context.executing_eagerly():
self._name = context.context().scope_name
+ else:
+ self._name = self._queue_ref.op.name.split("/")[-1]
@staticmethod
def from_list(index, queues):
@@ -231,9 +231,9 @@ class QueueBase(object):
@property
def name(self):
"""The name of the underlying queue."""
- if context.in_graph_mode():
- return self._queue_ref.op.name
- return self._name
+ if context.executing_eagerly():
+ return self._name
+ return self._queue_ref.op.name
@property
def dtypes(self):
@@ -444,7 +444,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
op = ret[0].op
for output, shape in zip(op.values(), self._shapes):
output.set_shape(shape)
@@ -484,7 +484,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
op = ret[0].op
batch_dim = tensor_shape.Dimension(
tensor_util.constant_value(op.inputs[1]))
@@ -528,7 +528,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
op = ret[0].op
for output, shape in zip(op.values(), self._shapes):
output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
@@ -990,10 +990,10 @@ class Barrier(object):
shapes=self._shapes,
shared_name=shared_name,
name=name)
- if context.in_graph_mode():
- self._name = self._barrier_ref.op.name.split("/")[-1]
- else:
+ if context.executing_eagerly():
self._name = context.context().scope_name
+ else:
+ self._name = self._barrier_ref.op.name.split("/")[-1]
@property
def barrier_ref(self):
@@ -1003,9 +1003,9 @@ class Barrier(object):
@property
def name(self):
"""The name of the underlying barrier."""
- if context.in_graph_mode():
- return self._barrier_ref.op.name
- return self._name
+ if context.executing_eagerly():
+ return self._name
+ return self._barrier_ref.op.name
def insert_many(self, component_index, keys, values, name=None):
"""For each key, assigns the respective value to the specified component.
@@ -1083,7 +1083,7 @@ class Barrier(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Barrier object.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
op = ret[0].op
if allow_small_batch:
batch_dim = None
@@ -1183,10 +1183,10 @@ class ConditionalAccumulatorBase(object):
else:
self._shape = tensor_shape.unknown_shape()
self._accumulator_ref = accumulator_ref
- if context.in_graph_mode():
- self._name = self._accumulator_ref.op.name.split("/")[-1]
- else:
+ if context.executing_eagerly():
self._name = context.context().scope_name
+ else:
+ self._name = self._accumulator_ref.op.name.split("/")[-1]
@property
def accumulator_ref(self):
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 09a0e345f2..8f5673597e 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -90,7 +90,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- in_graph_mode = context.in_graph_mode()
+ in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldl", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@@ -178,7 +178,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
- in_graph_mode = context.in_graph_mode()
+ in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldr", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@@ -343,7 +343,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = context.in_graph_mode()
+ in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@@ -536,7 +536,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = context.in_graph_mode()
+ in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "scan", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index b678090542..44473ec69c 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -86,7 +86,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
% str(value))
# TODO(mrry): Consider adding static shape information to
# IndexedSlices, to avoid using numpy here.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
dense_shape_value = tensor_util.constant_value(value.dense_shape)
if dense_shape_value is not None:
num_elements = np.prod(dense_shape_value)
@@ -491,9 +491,10 @@ def gradients(ys,
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients):
"""Implementation of gradients()."""
- if context.in_eager_mode():
- raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
- "functions in tf.contrib.eager.backprop instead.")
+ if context.executing_eagerly():
+ raise RuntimeError("tf.gradients not supported when eager execution "
+ "is enabled. Use tf.contrib.eager.GradientTape "
+ "instead.")
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 7c782c12a5..f6a25610c5 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -173,7 +173,7 @@ class ReaderBase(object):
Raises:
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError(
"Readers are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.")
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index baf7cc19fa..6f043f60e6 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -157,10 +157,10 @@ class InitializableLookupTableBase(LookupInterface):
default_value: The value to use if a key is missing in the table.
initializer: The table initializer to use.
"""
- if context.in_graph_mode():
- name = table_ref.op.name.split("/")[-1]
- else:
+ if context.executing_eagerly():
name = context.context().scope_name
+ else:
+ name = table_ref.op.name.split("/")[-1]
super(InitializableLookupTableBase,
self).__init__(initializer.key_dtype, initializer.value_dtype,
name)
@@ -521,7 +521,7 @@ class TextFileInitializer(TableInitializerBase):
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g., if
# it is a placeholder) then it does not make sense to track it as an asset.
- if context.in_graph_mode() and constant_op.is_constant(filename):
+ if not context.executing_eagerly() and constant_op.is_constant(filename):
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
return init_op
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 0cae3c1453..424fd09e09 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -136,7 +136,7 @@ def _num_present(losses, weights, per_batch=False):
`[batch_size]`. Otherwise, a single scalar tensor is returned.
"""
if ((isinstance(weights, float) and weights != 0.0) or
- (context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access
+ (context.executing_eagerly() and weights._rank() == 0 # pylint: disable=protected-access
and not math_ops.equal(weights, 0.0))):
return _num_elements(losses)
with ops.name_scope(None, "num_present", (losses, weights)) as scope:
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 55dd0c0e0d..e2ee9e4fe4 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -52,14 +52,14 @@ def _SumGrad(op, grad):
if axes is not None:
rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
- if context.in_graph_mode():
- new_shape = [1] * rank
- else:
+ if context.executing_eagerly():
ctx = context.context()
new_shape = ctx.ones_rank_cache().get(rank)
if new_shape is None:
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
ctx.ones_rank_cache().put(rank, new_shape)
+ else:
+ new_shape = [1] * rank
grad = array_ops.reshape(grad, new_shape)
# If shape is not fully defined (but rank is), we use Shape.
if None not in input_0_shape:
@@ -997,7 +997,7 @@ def _SparseMatMulGrad(op, grad):
op.inputs[0]: op.get_attr("a_is_sparse"),
op.inputs[1]: op.get_attr("b_is_sparse"),
# Use heuristic to figure out if grad might be sparse
- grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
+ grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
}
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c019a5851f..5130c50717 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2007,14 +2007,14 @@ def matmul(a,
if transpose_b and adjoint_b:
raise ValueError("Only one of transpose_b and adjoint_b can be True.")
- if context.in_graph_mode():
- a = ops.convert_to_tensor(a, name="a")
- b = ops.convert_to_tensor(b, name="b")
- else:
+ if context.executing_eagerly():
if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
a = ops.convert_to_tensor(a, name="a")
if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
b = ops.convert_to_tensor(b, name="b")
+ else:
+ a = ops.convert_to_tensor(a, name="a")
+ b = ops.convert_to_tensor(b, name="b")
# TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access
@@ -2249,7 +2249,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
return inputs[0]
elif len(inputs) == 1 and name is not None:
return array_ops.identity(inputs[0], name=name)
- elif context.in_eager_mode():
+ elif context.executing_eagerly():
# TemporaryVariable not currently supported in eager mode; fall back
# onto AddN for now.
# TODO(frreiss) remove this once the lifetime of eager variables gets
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index d314124ccd..9f85188b35 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -60,7 +60,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
def testReduceInvalidAxis(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
# The shape check is in run a graph construction time. In eager mode,
# it misses the check, magically return result given wrong shape.
return
@@ -249,7 +249,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
def testAcceptsRefs(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
var = resource_variable_ops.ResourceVariable(10, name="var")
else:
var = variables.Variable(10)
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 0123162b54..9ec4954579 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -308,7 +308,7 @@ def mean(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean is not supported when eager execution '
'is enabled.')
@@ -394,7 +394,7 @@ def accuracy(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.accuracy is not supported when eager '
'execution is enabled.')
@@ -644,7 +644,7 @@ def auc(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.')
@@ -758,7 +758,7 @@ def mean_absolute_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
'when eager execution is enabled.')
@@ -818,7 +818,7 @@ def mean_cosine_distance(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
'eager execution is enabled.')
@@ -891,7 +891,7 @@ def mean_per_class_accuracy(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
'when eager execution is enabled.')
@@ -996,7 +996,7 @@ def mean_iou(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_iou is not supported when '
'eager execution is enabled.')
@@ -1098,7 +1098,7 @@ def mean_relative_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
'eager execution is enabled.')
@@ -1165,7 +1165,7 @@ def mean_squared_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
'eager execution is enabled.')
@@ -1223,7 +1223,7 @@ def mean_tensor(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_tensor is not supported when '
'eager execution is enabled.')
@@ -1304,7 +1304,7 @@ def percentage_below(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.percentage_below is not supported when '
'eager execution is enabled.')
@@ -1397,7 +1397,7 @@ def false_negatives(labels,
or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives is not supported when '
'eager execution is enabled.')
@@ -1453,7 +1453,7 @@ def false_negatives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -1507,7 +1507,7 @@ def false_positives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives is not supported when '
'eager execution is enabled.')
@@ -1563,7 +1563,7 @@ def false_positives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -1617,7 +1617,7 @@ def true_negatives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives is not '
'supported when eager execution is enabled.')
@@ -1673,7 +1673,7 @@ def true_negatives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -1727,7 +1727,7 @@ def true_positives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives is not '
'supported when eager execution is enabled.')
@@ -1783,7 +1783,7 @@ def true_positives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -1851,7 +1851,7 @@ def precision(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision is not '
'supported when eager execution is enabled.')
@@ -1947,7 +1947,7 @@ def precision_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -2023,7 +2023,7 @@ def recall(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall is not supported is not '
'supported when eager execution is enabled.')
@@ -2400,7 +2400,7 @@ def recall_at_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_k is not '
'supported when eager execution is enabled.')
@@ -2549,7 +2549,7 @@ def recall_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_thresholds is not '
'supported when eager execution is enabled.')
@@ -2626,7 +2626,7 @@ def root_mean_squared_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.root_mean_squared_error is not '
'supported when eager execution is enabled.')
@@ -2707,7 +2707,7 @@ def sensitivity_at_specificity(labels,
or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
'supported when eager execution is enabled.')
@@ -3098,7 +3098,7 @@ def average_precision_at_k(labels,
ValueError: if k is invalid.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not '
'supported when eager execution is enabled.')
@@ -3267,7 +3267,7 @@ def precision_at_top_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_top_k is not '
'supported when eager execution is enabled.')
@@ -3396,7 +3396,7 @@ def precision_at_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_precision_at_k is not '
'supported when eager execution is enabled.')
@@ -3473,7 +3473,7 @@ def specificity_at_sensitivity(labels,
or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
'supported when eager execution is enabled.')
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 5582daf2da..4af5bd26dd 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -456,7 +456,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
def IsZero(g):
# Some introspection to check if the gradient is feeding zeros
- if context.in_eager_mode():
+ if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here.
return False
if g.op.type in ("ZerosLike", "Zeros"):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 66a05f2228..fb3fe77b4d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1504,7 +1504,7 @@ def bias_add(value, bias, data_format=None, name=None):
A `Tensor` with the same type as `value`.
"""
with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
- if context.in_graph_mode():
+ if not context.executing_eagerly():
value = ops.convert_to_tensor(value, name="input")
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
@@ -1616,7 +1616,7 @@ def _flatten_outer_dims(logits):
output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
# Set output shape if known.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
shape = logits.get_shape()
if shape is not None and shape.dims is not None:
shape = shape.as_list()
@@ -1881,7 +1881,8 @@ def softmax_cross_entropy_with_logits_v2(
# Make shape inference work since reshape and transpose may erase its static
# shape.
- if context.in_graph_mode() and shape is not None and shape.dims is not None:
+ if not context.executing_eagerly(
+ ) and shape is not None and shape.dims is not None:
shape = shape.as_list()
del shape[dim]
cost.set_shape(shape)
@@ -2318,7 +2319,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor
- if context.in_graph_mode():
+ if not context.executing_eagerly():
ret.set_shape(x.get_shape())
return ret
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index b4ce1cbf25..d348e47f57 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -74,7 +74,7 @@ def add_check_numerics_ops():
the checked operations.
@enc_compatibility
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError(
"add_check_numerics_ops() is not compatible with eager execution. "
"To check for Inf's and NaN's under eager execution, call "
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index d0578f8205..54191ee765 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -135,10 +135,10 @@ class EagerResourceDeleter(object):
# valid, and so on. Printing warnings in these cases is silly
# (exceptions raised from __del__ are printed as warnings to stderr).
pass # 'NoneType' object is not callable when the handle has been
- # partially unloaded.
+ # partially unloaded.
except AttributeError:
pass # 'NoneType' object has no attribute 'eager_mode' when context has
- # been unloaded. Will catch other module unloads as well.
+ # been unloaded. Will catch other module unloads as well.
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
@@ -267,9 +267,9 @@ class ResourceVariable(variables.Variable):
if initial_value is not None:
raise ValueError("variable_def and initial_value are mutually "
"exclusive.")
- if not context.in_graph_mode():
- raise ValueError("Creating ResourceVariable from variable_def"
- " only supported in GRAPH mode.")
+ if context.executing_eagerly():
+ raise ValueError("Creating ResourceVariable from variable_def is "
+ "not supported when eager execution is enabled.")
self._init_from_proto(variable_def, import_scope=import_scope)
else:
self._init_from_args(
@@ -363,7 +363,7 @@ class ResourceVariable(variables.Variable):
# this graph.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with ops.init_scope():
- self._in_graph_mode = context.in_graph_mode()
+ self._in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access
@@ -470,7 +470,7 @@ class ResourceVariable(variables.Variable):
self._cached_value = self._read_variable_op()
else:
self._cached_value = None
- if context.in_graph_mode():
+ if not context.executing_eagerly():
ops.add_to_collections(collections, self)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
@@ -489,7 +489,7 @@ class ResourceVariable(variables.Variable):
def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto."""
# Note that init_from_proto is currently not supported in Eager mode.
- assert context.in_graph_mode()
+ assert not context.executing_eagerly()
self._in_graph_mode = True
assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource:
@@ -582,7 +582,8 @@ class ResourceVariable(variables.Variable):
def create(self):
"""The op responsible for initializing this variable."""
if not self._in_graph_mode:
- raise RuntimeError("Calling create in EAGER mode not supported.")
+ raise RuntimeError("Calling create is not supported when eager execution"
+ " is enabled.")
return self._initializer_op
@property
@@ -610,7 +611,7 @@ class ResourceVariable(variables.Variable):
@property
def initial_value(self):
"""Returns the Tensor used as the initial value for the variable."""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("initial_value not supported in EAGER mode.")
return self._initial_value
@@ -631,15 +632,15 @@ class ResourceVariable(variables.Variable):
def eval(self, session=None):
"""Evaluates and returns the value of this variable."""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("Trying to eval in EAGER mode")
return self._graph_element.eval(session=session)
def numpy(self):
- if context.in_graph_mode():
- raise NotImplementedError(
- "numpy() is only available when eager execution is enabled.")
- return self.read_value().numpy()
+ if context.executing_eagerly():
+ return self.read_value().numpy()
+ raise NotImplementedError(
+ "numpy() is only available when eager execution is enabled.")
def count_up_to(self, limit):
"""Increments this variable until it reaches `limit`.
@@ -720,7 +721,7 @@ class ResourceVariable(variables.Variable):
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("to_proto not supported in EAGER mode.")
if export_scope is None or self.handle.name.startswith(export_scope):
var_def = variable_pb2.VariableDef()
@@ -747,7 +748,7 @@ class ResourceVariable(variables.Variable):
@staticmethod
def from_proto(variable_def, import_scope=None):
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("from_proto not supported in EAGER mode.")
return ResourceVariable(
variable_def=variable_def, import_scope=import_scope)
@@ -984,10 +985,10 @@ class _UnreadVariable(ResourceVariable):
self._is_initialized_op = None
self._initializer_op = None
self._parent_op = parent_op
- if context.in_graph_mode():
- self._graph_element = self.read_value()
- else:
+ if context.executing_eagerly():
self._graph_element = None
+ else:
+ self._graph_element = self.read_value()
self._handle_deleter = deleter
def value(self):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index aa8d4327d2..625d433b1f 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -575,7 +575,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -616,7 +616,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
["Expected shape for Tensor %s is " % x.name,
packed_shape, " but saw shape: ", x_shape])
- if context.in_graph_mode() and sequence_length is not None:
+ if not context.executing_eagerly() and sequence_length is not None:
# Perform some shape validation
with ops.control_dependencies(
[_assert_has_shape(sequence_length, [batch_size])]):
@@ -742,7 +742,7 @@ def _dynamic_rnn_loop(cell,
element_shape=element_shape,
tensor_array_name=base_name + name)
- in_graph_mode = context.in_graph_mode()
+ in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
output_ta = tuple(
_create_ta(
@@ -1027,7 +1027,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1242,7 +1242,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 3ae1d1184d..e61d10835f 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -128,7 +128,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
"""Combine s with batch_size to get a proper tensor shape."""
c = _concat(batch_size, s)
size = array_ops.zeros(c, dtype=dtype)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
c_static = _concat(batch_size, s, static=True)
size.set_shape(c_static)
return size
@@ -192,12 +192,13 @@ class RNNCell(base_layer.Layer):
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
- if context.in_graph_mode():
- trainable = (variable in tf_variables.trainable_variables() or
- (isinstance(variable, tf_variables.PartitionedVariable) and
- list(variable)[0] in tf_variables.trainable_variables()))
- else:
+ if context.executing_eagerly():
trainable = variable._trainable # pylint: disable=protected-access
+ else:
+ trainable = (
+ variable in tf_variables.trainable_variables() or
+ (isinstance(variable, tf_variables.PartitionedVariable) and
+ list(variable)[0] in tf_variables.trainable_variables()))
if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights:
@@ -241,7 +242,7 @@ class RNNCell(base_layer.Layer):
# Try to use the last cached zero_state. This is done to avoid recreating
# zeros, especially when eager execution is enabled.
state_size = self.state_size
- is_eager = context.in_eager_mode()
+ is_eager = context.executing_eagerly()
if is_eager and hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 01f0b81684..529eebe769 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -317,7 +317,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
result = func(*[x.numpy() for x in inp])
result = nest.flatten(result)
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index fd4419640a..c3ad5831b4 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -186,7 +186,7 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables.
- if context.in_eager_mode() or ref.op.type == "VarHandleOp":
+ if context.executing_eagerly() or ref.op.type == "VarHandleOp":
return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
name=name)
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 70e8040512..0a391d896a 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -204,7 +204,7 @@ def make_template_internal(name_,
if kwargs:
func_ = tf_decorator.make_decorator(func_, functools.partial(
func_, **kwargs))
- if context.in_eager_mode():
+ if context.executing_eagerly():
if unique_name_ is not None:
raise ValueError(
"unique_name_ cannot be used when eager exeuction is enabled.")
@@ -364,7 +364,7 @@ class Template(checkpointable.CheckpointableBase):
"""
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
- # we don't want to propagate.
+ # we don't want to propagate.
return next_creator(
initial_value=initializer,
name=name,
@@ -647,7 +647,7 @@ class EagerTemplate(Template):
Raises:
RuntimeError: if eager execution is not enabled.
"""
- if not context.in_eager_mode():
+ if not context.executing_eagerly():
raise RuntimeError(
"{} objects can only be used when eager execution is enabled, use "
"tf.Template for graph construction".
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 6226f426be..2f6badcb53 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -338,7 +338,7 @@ class _GraphTensorArray(object):
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
- if self._infer_shape and context.in_graph_mode():
+ if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
@@ -363,7 +363,7 @@ class _GraphTensorArray(object):
value = ops.convert_to_tensor(value, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.to_int64(lengths)
- if self._infer_shape and context.in_graph_mode():
+ if self._infer_shape and not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
if value.shape.dims is not None:
if clengths is not None and clengths.max() == clengths.min():
@@ -774,10 +774,10 @@ class TensorArray(object):
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
- if context.in_graph_mode():
- implementation = _GraphTensorArray
- else:
+ if context.executing_eagerly():
implementation = _EagerTensorArray
+ else:
+ implementation = _GraphTensorArray
self._implementation = implementation(
dtype,
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index de4e44f60c..7f650ff6a9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -321,7 +321,7 @@ class _VariableStore(object):
raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter)
- if context.in_eager_mode():
+ if context.executing_eagerly():
if not self._store_eager_variables and reuse:
raise RuntimeError(
"When eager execution is enabled variable reuse is only supported"
@@ -518,7 +518,7 @@ class _VariableStore(object):
when violating reuse during variable creation, or if an existing
sharded variable exists for the given name but with different sharding.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
@@ -798,7 +798,7 @@ class _VariableStore(object):
validate_shape=validate_shape,
constraint=constraint,
use_resource=use_resource)
- if context.in_graph_mode() or self._store_eager_variables:
+ if not context.executing_eagerly() or self._store_eager_variables:
# In eager mode we do not want to keep default references to Variable
# objects as this will prevent their memory from being released.
self._vars[name] = v
@@ -811,12 +811,12 @@ class _VariableStore(object):
with ops.name_scope(name + "/Regularizer/"):
loss = regularizer(v)
if loss is not None:
- if context.in_graph_mode():
- v_name = v.name
- loss_name = loss.name
- else:
+ if context.executing_eagerly():
v_name = "v_%s" % type(v)
loss_name = "loss_%s" % type(loss)
+ else:
+ v_name = v.name
+ loss_name = loss.name
logging.vlog(1, "Applied regularizer to %s and added the result %s "
"to REGULARIZATION_LOSSES.", v_name, loss_name)
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
@@ -920,7 +920,7 @@ class VariableScope(object):
self._dtype = dtype
self._use_resource = use_resource
self._constraint = constraint
- if context.in_eager_mode():
+ if context.executing_eagerly():
if self._caching_device is not None:
raise NotImplementedError("Caching devices is not yet supported "
"when eager execution is enabled.")
@@ -988,7 +988,7 @@ class VariableScope(object):
def set_use_resource(self, use_resource):
"""Sets whether to use ResourceVariables for this scope."""
- if context.in_eager_mode() and not use_resource:
+ if context.executing_eagerly() and not use_resource:
raise ValueError("When eager execution is enabled, "
"use_resource cannot be set to false.")
self._use_resource = use_resource
@@ -999,14 +999,14 @@ class VariableScope(object):
def set_caching_device(self, caching_device):
"""Set caching_device for this scope."""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise NotImplementedError("Caching devices are not yet supported "
"when eager execution is enabled.")
self._caching_device = caching_device
def set_partitioner(self, partitioner):
"""Set partitioner for this scope."""
- if partitioner and context.in_eager_mode():
+ if partitioner and context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
self._partitioner = partitioner
@@ -1057,14 +1057,14 @@ class VariableScope(object):
partitioner = self._partitioner
if custom_getter is None:
custom_getter = self._custom_getter
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ reuse = False
+ use_resource = True
+ else:
if reuse is None:
reuse = self._reuse
if use_resource is None:
use_resource = self._use_resource
- else:
- reuse = False
- use_resource = True
full_name = self.name + "/" + name if self.name else name
# Variable names only depend on variable_scope (full_name here),
@@ -1107,7 +1107,7 @@ class VariableScope(object):
use_resource=None,
constraint=None):
"""Gets an existing variable with this name or create a new one."""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
if initializer is None:
@@ -1871,7 +1871,7 @@ class variable_scope(object):
raise ValueError("The reuse parameter must be True or False or None.")
if self._values is None:
self._values = []
- self._in_graph_mode = not context.in_eager_mode()
+ self._in_graph_mode = not context.executing_eagerly()
if self._in_graph_mode:
self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access
self._cached_pure_variable_scope = None
@@ -2111,13 +2111,13 @@ def default_variable_creator(next_creator=None, **kwargs):
use_resource = kwargs.get("use_resource", None)
if use_resource is None:
use_resource = get_variable_scope().use_resource
- if use_resource or (use_resource is None and context.in_eager_mode()):
+ if use_resource or (use_resource is None and context.executing_eagerly()):
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
constraint=constraint)
- elif not use_resource and context.in_eager_mode():
+ elif not use_resource and context.executing_eagerly():
raise RuntimeError(
"VariableScope should use resource variable when eager execution is"
" enabled, but use_resource is False."
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 643a3b7edc..5b9947f441 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -210,10 +210,11 @@ class Variable(checkpointable.CheckpointableBase):
for details on how variables work in eager execution.
@end_compatibility
"""
- if not context.in_graph_mode():
- raise RuntimeError("tf.Variable not supported in Eager mode. "
- "Please use tfe.Variable instead")
- self._in_graph_mode = context.in_graph_mode()
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "tf.Variable not supported when eager execution is enabled. "
+ "Please use tf.contrib.eager.Variable instead")
+ self._in_graph_mode = True
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
if initial_value:
@@ -234,7 +235,7 @@ class Variable(checkpointable.CheckpointableBase):
constraint=constraint)
def __repr__(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
self.name, self.get_shape(), self.dtype.name,
ops.numpy_text(self.read_value(), is_repr=True))
@@ -740,15 +741,15 @@ class Variable(checkpointable.CheckpointableBase):
Raises:
ValueError: Session is not passed and no default session
"""
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ self.assign(value)
+ else:
session = session or ops.get_default_session()
if session is None:
raise ValueError(
"Either session argument should be provided or default session "
"should be established")
session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
- else:
- self.assign(value)
# Conversion to tensor.
@staticmethod
@@ -1248,9 +1249,9 @@ class PartitionedVariable(object):
information does not match `shape`, or `partitions` has invalid values.
RuntimeError: If eager execution is enabled
"""
- if not context.in_graph_mode():
- raise RuntimeError("tf.PartitionedVariable not supported in "
- "eager mode. Please use tfe.Variable instead")
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "tf.PartitionedVariable not supported with eager execution enabled.")
if not isinstance(variable_list, (list, tuple)):
raise TypeError(
"variable_list is not a list or tuple: %s" % variable_list)
@@ -1541,7 +1542,7 @@ def variables_initializer(var_list, name="init"):
Returns:
An Op that run the initializers of all the specified variables.
"""
- if var_list and context.in_graph_mode():
+ if var_list and not context.executing_eagerly():
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
return control_flow_ops.no_op(name=name)
@@ -1563,7 +1564,7 @@ def global_variables_initializer():
Returns:
An Op that initializes global variables in the graph.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return control_flow_ops.no_op(name="global_variables_initializer")
return variables_initializer(global_variables())
@@ -1585,7 +1586,7 @@ def local_variables_initializer():
Returns:
An Op that initializes all local variables in the graph.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return control_flow_ops.no_op(name="local_variables_initializer")
return variables_initializer(local_variables())