path: root/tensorflow/python/layers
author    Asim Shankar <ashankar@google.com>            2018-03-07 12:03:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-07 12:10:42 -0800
commit 37cef895bfe06913477b87917cbee7284aefa7cd (patch)
tree   4f05a013578c0459a52fc5e6448bb3dfc2d04971 /tensorflow/python/layers
parent 808b569e85df8d63590740f05bc14d964efc4801 (diff)
eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation to introduce one public, stable symbol: tf.executing_eagerly() (i.e., part of moving APIs related to eager execution from "contrib" to a namespace where we provide API stability guarantees).

PiperOrigin-RevId: 188212646
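In practice the rename collapses two mode predicates into one. A minimal sketch of the resulting call pattern, assuming the tensorflow.python.eager context module of this era (illustrative only, not part of the commit):

  from tensorflow.python.eager import context

  # Before this commit, callers branched on one of two predicates:
  #   context.in_eager_mode()  and  context.in_graph_mode()
  # After it, a single predicate covers both cases; graph-mode checks
  # become its negation:
  if context.executing_eagerly():
      print('eager path: updates and losses are applied immediately')
  else:
      print('graph path: static shapes, collections, and graph ops apply')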
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--  tensorflow/python/layers/base.py           | 65
-rw-r--r--  tensorflow/python/layers/base_test.py      | 32
-rw-r--r--  tensorflow/python/layers/convolutional.py  |  4
-rw-r--r--  tensorflow/python/layers/core.py           |  4
-rw-r--r--  tensorflow/python/layers/core_test.py      | 12
-rw-r--r--  tensorflow/python/layers/normalization.py  | 16
6 files changed, 68 insertions(+), 65 deletions(-)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 15f72786de..e9066d3fda 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -115,7 +115,7 @@ class Layer(checkpointable.CheckpointableBase):
# Provides information about which inputs are compatible with the layer.
self.input_spec = None
- if activity_regularizer and context.in_eager_mode():
+ if activity_regularizer and context.executing_eagerly():
raise ValueError(
('Activity regularization is not supported when executing eagerly. '
'Got activity_regularizer=%s') % (activity_regularizer,))
@@ -228,7 +228,7 @@ class Layer(checkpointable.CheckpointableBase):
@property
def updates(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.updates not supported in Eager mode.')
if not self.trainable and not self.stateful:
return []
@@ -260,7 +260,7 @@ class Layer(checkpointable.CheckpointableBase):
have is available at runtime.
A step counter might fall into this category.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
return # Updates already applied when in eager mode.
updates = _to_list(updates)
@@ -286,7 +286,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
# Updates disabled if layer is not trainable and not explicitly stateful.
@@ -317,7 +317,7 @@ class Layer(checkpointable.CheckpointableBase):
Returns:
A list of tensors.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
# _losses may only contain variable regularization losses when executing
# eagerly, and they have been saved as lambdas to be executed when
# requested.
@@ -355,7 +355,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
# TODO(fchollet): it should be possible (and highly desirable) to support
# `add_loss` in eager mode. This allows great convenience and flexibility
# in defining custom losses on the fly (e.g. in VAEs).
@@ -389,7 +389,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
if inputs is None:
@@ -509,7 +509,7 @@ class Layer(checkpointable.CheckpointableBase):
# will occur; it should be None if and only if initialization will take
# place in the eager context.
init_graph = None
- if context.in_graph_mode():
+ if not context.executing_eagerly():
default_graph = ops.get_default_graph()
if default_graph.building_function:
with ops.init_scope():
@@ -517,7 +517,7 @@ class Layer(checkpointable.CheckpointableBase):
# will be lifted; if initialization ops will be lifted into
# the eager context, then there is nothing to retrieve, since variable
# collections are not supported when eager execution is enabled.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
init_graph = ops.get_default_graph()
existing_variables = set(tf_variables.global_variables())
else:
@@ -624,17 +624,17 @@ class Layer(checkpointable.CheckpointableBase):
self._set_scope(kwargs.pop('scope', None))
input_list = nest.flatten(inputs)
- in_graph_mode = context.in_graph_mode()
+ build_graph = not context.executing_eagerly()
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
- if in_graph_mode:
+ if build_graph:
try:
# Set layer's "graph" at build time
self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
- if in_graph_mode or in_deferred_mode:
+ if build_graph or in_deferred_mode:
user_kwargs = copy.copy(kwargs)
# Handle Keras mask propagation from previous layer to current layer.
@@ -669,13 +669,14 @@ class Layer(checkpointable.CheckpointableBase):
with scope_context_manager as scope:
with ops.name_scope(self._name_scope_name(scope)):
if not self.built:
- if not in_graph_mode:
+ if not build_graph:
# Activity regularization is currently unsupported in Eager mode.
if self._activity_regularizer:
- raise ValueError('activity_regularizer currently unsupported in '
- 'Eager mode. Found an activity_regularizer in '
- '%s(%s).' % (self.__class__.__name__, self))
- if not in_graph_mode and not in_deferred_mode:
+ raise ValueError(
+ 'activity_regularizer currently unsupported with '
+ 'eager execution enabled. Found an activity_regularizer in '
+ '%s(%s).' % (self.__class__.__name__, self))
+ if not build_graph and not in_deferred_mode:
# TODO(agarwal): support _keras_history in Eager mode.
for x in input_list:
if hasattr(x, '_keras_history'):
@@ -706,7 +707,7 @@ class Layer(checkpointable.CheckpointableBase):
if call_has_scope_arg:
kwargs['scope'] = scope
# Check input assumptions set after layer building, e.g. input shape.
- if in_graph_mode or in_deferred_mode:
+ if build_graph or in_deferred_mode:
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
@@ -730,7 +731,7 @@ class Layer(checkpointable.CheckpointableBase):
if len(outputs) == 1:
outputs = outputs[0]
- if in_graph_mode:
+ if build_graph:
# Apply activity regularization.
# Note that it should be applied every time the layer creates a new
# output, since it is output-specific.
@@ -752,7 +753,7 @@ class Layer(checkpointable.CheckpointableBase):
else:
outputs._keras_mask = output_mask # pylint: disable=protected-access
- if in_graph_mode:
+ if build_graph:
# If all input tensors have history metadata,
# we update the output tensors
# with corresponding history metadata, thus eventually allowing to use
@@ -775,7 +776,7 @@ class Layer(checkpointable.CheckpointableBase):
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
- if in_deferred_mode or in_graph_mode:
+ if in_deferred_mode or build_graph:
if _have_all_keras_metadata(inputs):
# Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s).
@@ -787,7 +788,7 @@ class Layer(checkpointable.CheckpointableBase):
@property
def graph(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.graph not supported in Eager mode.')
return self._graph
@@ -891,7 +892,7 @@ class Layer(checkpointable.CheckpointableBase):
mode.
ValueError: If the index provided does not match any node.
"""
- assert context.in_graph_mode()
+ assert not context.executing_eagerly()
if not self._inbound_nodes:
raise RuntimeError('The layer has never been called '
'and thus has no defined ' + attr_name + '.')
@@ -921,7 +922,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError(
'Layer.get_input_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_shapes',
@@ -943,7 +944,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError(
'Layer.get_output_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_shapes',
@@ -964,7 +965,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_tensors',
'input')
@@ -984,7 +985,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.get_output_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_tensors',
'output')
@@ -1007,7 +1008,7 @@ class Layer(checkpointable.CheckpointableBase):
RuntimeError: If called in Eager mode.
AttributeError: If no inbound nodes are found.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.input not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name +
@@ -1029,7 +1030,7 @@ class Layer(checkpointable.CheckpointableBase):
layers.
RuntimeError: if called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.output not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
@@ -1051,7 +1052,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined input_shape.
RuntimeError: if called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.input_shape not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('The layer has never been called '
@@ -1112,7 +1113,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined output shape.
RuntimeError: if called in Eager mode.
"""
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Layer.output_shape not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('The layer has never been called '
@@ -1470,7 +1471,7 @@ def _to_list(x):
def _add_elements_to_collection(elements, collection_list):
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError('Using collections from Layers not supported in Eager '
'mode. Tried to add %s to %s' % (elements,
collection_list))
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 1ee9ec7f7a..9ed4afeaba 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -44,7 +44,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [])
self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, [])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# updates, losses only supported in GRAPH mode
self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, [])
@@ -63,7 +63,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(
layer.variables,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@@ -77,7 +77,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable, variable_2])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [variable_2])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@@ -161,7 +161,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# op is only supported in GRAPH mode
self.assertEqual(outputs.op.name, 'my_layer/Square')
@@ -210,7 +210,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# op only supported in GRAPH mode.
self.assertEqual(outputs.op.name, 'my_layer/Square')
@@ -280,7 +280,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
- if context.in_graph_mode():
+ if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@@ -307,7 +307,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
- if context.in_graph_mode():
+ if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@@ -335,7 +335,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
- if context.in_graph_mode():
+ if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@@ -430,7 +430,7 @@ class BaseLayerTest(test.TestCase):
layer.apply(constant_op.constant(1))
# Works
- if context.in_graph_mode():
+ if not context.executing_eagerly():
layer.apply(array_ops.placeholder('int32'))
layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
@@ -453,13 +453,7 @@ class BaseLayerTest(test.TestCase):
return {'l' + key: inputs[key] for key in inputs}
layer = DictLayer()
- if context.in_graph_mode():
- i1 = array_ops.placeholder('int32')
- i2 = array_ops.placeholder('float32')
- result = layer.apply({'abel': i1, 'ogits': i2})
- self.assertTrue(isinstance(result, dict))
- self.assertEqual(set(['label', 'logits']), set(result.keys()))
- else:
+ if context.executing_eagerly():
i1 = constant_op.constant(3)
i2 = constant_op.constant(4.0)
result = layer.apply({'abel': i1, 'ogits': i2})
@@ -467,6 +461,12 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(set(['label', 'logits']), set(result.keys()))
self.assertEqual(3, result['label'].numpy())
self.assertEqual(4.0, result['logits'].numpy())
+ else:
+ i1 = array_ops.placeholder('int32')
+ i2 = array_ops.placeholder('float32')
+ result = layer.apply({'abel': i1, 'ogits': i2})
+ self.assertTrue(isinstance(result, dict))
+ self.assertEqual(set(['label', 'logits']), set(result.keys()))
def testActivityRegularizer(self):
regularizer = math_ops.reduce_sum
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index bb10fe5e8b..74e7c63fb3 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -1664,7 +1664,7 @@ class Conv2DTranspose(Conv2D):
padding=self.padding.upper(),
data_format=utils.convert_data_format(self.data_format, ndim=4))
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
@@ -1969,7 +1969,7 @@ class Conv3DTranspose(Conv3D):
data_format=utils.convert_data_format(self.data_format, ndim=5),
padding=self.padding.upper())
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index bdbbc59eaf..e598d9f83a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -156,7 +156,7 @@ class Dense(base.Layer):
outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
[0]])
# Reshape the output back to the original ndim of the input.
- if context.in_graph_mode():
+ if not context.executing_eagerly():
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else:
@@ -374,7 +374,7 @@ class Flatten(base.Layer):
def call(self, inputs):
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
- if context.in_graph_mode():
+ if not context.executing_eagerly():
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
return outputs
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 15ce6cba21..ae19866d7a 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -77,7 +77,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.trainable_variables,
[dense.kernel, dense.bias])
self.assertListEqual(dense.non_trainable_variables, [])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@@ -98,7 +98,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.variables, [dense.kernel])
self.assertListEqual(dense.trainable_variables, [dense.kernel])
self.assertListEqual(dense.non_trainable_variables, [])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@@ -113,7 +113,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.non_trainable_variables,
[dense.kernel, dense.bias])
self.assertListEqual(dense.trainable_variables, [])
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
@@ -162,13 +162,13 @@ class DenseTest(test.TestCase):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense1/Relu')
dense = core_layers.Dense(2, name='dense2')
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
def testActivityRegularizer(self):
@@ -374,7 +374,7 @@ class DropoutTest(test.TestCase):
dp = core_layers.Dropout(0.5)
inputs = array_ops.ones((5, 3))
dropped = dp.apply(inputs, training=True)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
np_output = self.evaluate(dropped)
self.assertAlmostEqual(0., np_output.min())
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d83292b809..c23d755a8e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -338,8 +338,9 @@ class BatchNormalization(base.Layer):
return var
with ops.device(None):
- device = ((lambda _: self.moving_mean.device)
- if context.in_graph_mode() else self.moving_mean.device)
+ device = (
+ self.moving_mean.device if context.executing_eagerly() else
+ (lambda _: self.moving_mean.device))
with ops.device(device):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
@@ -347,8 +348,9 @@ class BatchNormalization(base.Layer):
# renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight.
- device = ((lambda _: self.moving_variance.device)
- if context.in_graph_mode() else self.moving_variance.device)
+ device = (
+ self.moving_variance.device if context.executing_eagerly() else
+ (lambda _: self.moving_variance.device))
with ops.device(device):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
self.renorm_stddev_weight = _renorm_variable(
@@ -420,7 +422,7 @@ class BatchNormalization(base.Layer):
one_minus_decay)
variance_update = self._assign_moving_average(self.moving_variance,
variance, one_minus_decay)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
# Note that in Eager mode, the updates are already executed when running
# assign_moving_averages. So we do not need to put them into
# collections.
@@ -493,7 +495,7 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
- in_eager_mode = context.in_eager_mode()
+ in_eager_mode = context.executing_eagerly()
if self.virtual_batch_size is not None:
# Virtual batches (aka ghost batches) can be simulated by reshaping the
# Tensor and reusing the existing batch norm implementation
@@ -610,7 +612,7 @@ class BatchNormalization(base.Layer):
training,
lambda: _do_update(self.moving_variance, new_variance),
lambda: self.moving_variance)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
self.add_update(mean_update, inputs=inputs)
self.add_update(variance_update, inputs=inputs)
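
For the normalization.py hunks above, the eager and graph branches deliberately pass different types to ops.device(): graph mode accepts a device function (op -> device string) that pins every op created in the scope, while eager execution expects a concrete device name. A short hedged sketch of that pattern (renorm_variable_device is a hypothetical helper name, not from the commit):

  from tensorflow.python.eager import context

  def renorm_variable_device(moving_average_var):
      # Hypothetical helper mirroring the diff: eager execution takes the
      # device name string directly; graph mode takes a device function so
      # ops created under ops.device(...) are pinned to the variable's device.
      if context.executing_eagerly():
          return moving_average_var.device
      return lambda _: moving_average_var.device

  # Usage mirroring the diff (sketch):
  #   with ops.device(renorm_variable_device(self.moving_mean)):
  #     self.renorm_mean = _renorm_variable('renorm_mean', param_shape)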