author    Allen Lavoie <allenl@google.com>                2018-09-28 09:27:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-28 09:33:20 -0700
commit    4eb53d3e5f7bec3c757a06d186ff31fe52083e6d
tree      b3844674c71f21e7a79ec014df9e395a80507400 /tensorflow
parent    f4014108a310928cd897085a8bc7d757c641a1c3
Simplify eager/graph Layer.losses conditionals
Fixes an issue where losses created while executing eagerly were returned as unevaluated lambdas in a defun.
Lazily evaluates Layer losses by default when possible. Even when graph building, this is generally preferable (e.g. for losses created inside a while_loop).
Allows calls to Layer.add_loss when executing eagerly, but only for losses that are not conditional on inputs (i.e. no activity regularizers).
PiperOrigin-RevId: 214947108
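
To make the new contract concrete, here is a minimal sketch of what the change permits. It is not part of the commit, and assumes a TF 1.x build that includes this change, with eager execution enabled:

```python
# Hedged illustration of the add_loss/losses behavior described above.
import tensorflow as tf

tf.enable_eager_execution()

layer = tf.keras.layers.Dense(
    1, kernel_regularizer=tf.keras.regularizers.l2(0.01))
layer(tf.ones([1, 10]))  # Builds the layer; the kernel regularizer is stored
                         # as a zero-argument callable, not evaluated yet.

# Each access to `losses` runs the stored callables against the variables'
# current values.
print(layer.losses)

# add_loss now works eagerly, but only for zero-argument callables that are
# unconditional on inputs (so no activity regularizers):
layer.add_loss(lambda: 0.01 * tf.reduce_sum(layer.kernel))
print(layer.losses)  # Includes the manually added loss as well.

# Passing a plain tensor (or inputs=...) still raises RuntimeError eagerly.
```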
Diffstat (limited to 'tensorflow')
 tensorflow/python/keras/engine/base_layer.py          | 157
 tensorflow/python/keras/engine/training_eager_test.py |  14
 tensorflow/python/keras/engine/training_test.py       |  12
 tensorflow/python/layers/base.py                      |  16
 tensorflow/python/layers/convolutional_test.py        |  36
 tensorflow/python/layers/core_test.py                 |   6
6 files changed, 140 insertions, 101 deletions
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e98b131ae6..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections as collections_lib
 import enum  # pylint: disable=g-bad-import-order
+import functools
 import inspect  # Necessary supplement to tf_inspect to deal with variadic args.
 
 import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
     self._trainable_weights = []
     self._non_trainable_weights = []
     self._updates = []
-    # When executing eagerly, _losses is a list of zero-argument lambdas which
-    # return tensors. When using graph execution, _losses is a list of ops.
+    # A list of zero-argument lambdas which return Tensors, used for variable
+    # regularizers.
+    self._callable_losses = []
+    # A list of Tensors containing activity regularizers and losses manually
+    # added through `add_loss`. Empty when executing eagerly.
     self._losses = []
+    self._in_call = False  # Flag for error checking in add_loss
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
     self._call_fn_args = function_utils.fn_args(self.call)
     self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
   def losses(self):
     """Losses which are associated with this `Layer`.
 
-    Note that when executing eagerly, getting this property evaluates
-    regularizers. When using graph execution, variable regularization ops have
-    already been created and are simply returned here.
+    Variable regularization tensors are created when this property is accessed,
+    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+    propagate gradients back to the corresponding variables.
 
     Returns:
       A list of tensors.
     """
-    if context.executing_eagerly():
-      # _losses may only contain variable regularization losses when executing
-      # eagerly, and they have been saved as lambdas to be executed when
-      # requested.
-      return [regularizer() for regularizer in self._losses]
-    else:
-      return self._losses
+    collected_losses = []
+    collected_losses.extend(self._losses)
+    for regularizer in self._callable_losses:
+      loss_tensor = regularizer()
+      if loss_tensor is not None:
+        collected_losses.append(loss_tensor)
+    return collected_losses
 
   @doc_controls.for_subclass_implementers
   def add_loss(self, losses, inputs=None):
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
       from `Layer.call()`).
 
     Arguments:
-      losses: Loss tensor, or list/tuple of tensors.
+      losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+        may also be zero-argument callables which create a loss tensor. Only
+        callable losses are supported when executing eagerly.
       inputs: If anything other than None is passed, it signals the losses
         are conditional on some of the layer's inputs, and thus they should
        only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
        (e.g. weight regularization losses).
 
     Raises:
-      RuntimeError: If called in Eager mode.
+      RuntimeError: If called in Eager mode with a `Tensor` rather than a
+        callable, or if `inputs` is not None.
     """
-    if context.executing_eagerly():
-      # TODO(fchollet): it should be possible (and highly desirable) to support
-      # `add_loss` in eager mode. This allows great convenience and flexibility
-      # in defining custom losses on the fly (e.g. in VAEs).
-      # Simply appending the loss value to `self._losses`
-      # is the correct behavior.
-      # The only caveat is that we need to force the user to only call
-      # `add_loss` from inside a model or Layer's `call` method
-      # (otherwise the loss computation cannot be backproped through).
-      raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+    executing_eagerly = context.executing_eagerly()
+    if executing_eagerly:
+      if inputs is not None:
+        raise RuntimeError(
+            'Activity regularization (via the "inputs" argument to '
+            'Layer.add_loss) is not supported when executing eagerly. Consider '
+            'returning activity regularization losses from a Model\'s call() '
+            'method.')
+      if getattr(self, '_in_call', False):
+        # TODO(psv): Support activity regularization and a way to reset losses.
+        raise RuntimeError(
+            'Adding losses inside a Layer\'s call() method is not currently '
+            'supported when executing eagerly. Please file a feature request '
+            'if you need this limitation lifted.')
     losses = generic_utils.to_list(losses)
-    losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
-              if not tensor_util.is_tensor(loss) else loss for loss in losses]
-    self._losses += losses
-    if inputs is None:
-      for loss in losses:
-        loss._unconditional_loss = True  # pylint: disable=protected-access
-    else:
-      for loss in losses:
-        loss._unconditional_loss = False  # pylint: disable=protected-access
+
+    def _tag_unconditional(loss):
+      if callable(loss):
+        loss = loss()
+      if loss is None:
+        return None  # Will be filtered out when computing the .losses property
+      if not tensor_util.is_tensor(loss):
+        loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+      loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
+      return loss
+
+    for loss in losses:
+      if callable(loss):
+        self._callable_losses.append(
+            functools.partial(_tag_unconditional, loss))
+      else:
+        if executing_eagerly:
+          raise RuntimeError(
+              'Layer.add_loss only supported for zero-argument lambdas when '
+              'executing eagerly.')
+        self._losses.append(_tag_unconditional(loss))
 
   def get_losses_for(self, inputs):
     """Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
     return variable
 
   def _handle_weight_regularization(self, name, variable, regularizer):
-    # `init_graph` should point to the graph in which variable initialization
-    # will occur; it should be None if and only if initialization will take
-    # place in the eager context.
-    init_graph = None
-    if not context.executing_eagerly():
-      default_graph = ops.get_default_graph()
-      if default_graph.building_function:
-        with ops.init_scope():
-          # Retrieve the variables from the graph into which variables
-          # will be lifted; if initialization ops will be lifted into
-          # the eager context, then there is nothing to retrieve, since variable
-          # collections are not supported when eager execution is enabled.
-          if not context.executing_eagerly():
-            init_graph = ops.get_default_graph()
-      else:
-        # Initialization ops will not be lifted out of the default graph.
-        init_graph = default_graph
-
-    if init_graph is not None:  # pylint: disable=protected-access
-      # The variable was created and initialized in a graph.
-      if regularizer:
-        if isinstance(variable, tf_variables.PartitionedVariable):
-          for v in variable:
-            with ops.colocate_with(v.op):
-              with ops.name_scope(name + '/Regularizer'):
-                regularization = regularizer(v)
-            if regularization is not None:
-              self.add_loss(regularization)
-        else:
-          with ops.colocate_with(variable.op):
-            with ops.name_scope(name + '/Regularizer'):
-              regularization = regularizer(variable)
-          if regularization is not None:
-            self.add_loss(regularization)
-    elif regularizer:  # initialization took place in an eager context
-      if isinstance(variable, tf_variables.PartitionedVariable):
-        raise RuntimeError(
-            'Partitioned variable regularization is not yet '
-            'supported when executing eagerly. File a feature request'
-            'if this is important to you.')
-      # Save a zero-argument lambda which runs the regularizer on the
-      # variable, to be executed when `Layer.losses` is requested.
-      # This makes losses responsive to variable updates when executing
-      # eagerly.
-      #
-      # TODO(akshayka): Do the same for graphs as well, so that losses
-      # collected in a while_loop can be run outside its control flow
-      # context and so that losses won't be swallowed up by graph functions
-      # (i.e., `.losses()` should always create regularizers).
-      self._losses.append(lambda: regularizer(variable))
+    """Create lambdas which compute regularization losses."""
+
+    def _loss_for_variable(v):
+      """Creates a regularization loss `Tensor` for variable `v`."""
+      with ops.colocate_with(v):
+        with ops.name_scope(name + '/Regularizer'):
+          regularization = regularizer(v)
+      return regularization
+
+    if isinstance(variable, tf_variables.PartitionedVariable):
+      for v in variable:
+        self.add_loss(functools.partial(_loss_for_variable, v))
+    else:
+      self.add_loss(functools.partial(_loss_for_variable, variable))
 
   def _handle_activity_regularization(self, inputs, outputs):
     # Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
           self._assert_input_compatibility(inputs)
 
         if not in_deferred_mode:
+          self._in_call = True
           outputs = self.call(inputs, *args, **kwargs)
+          self._in_call = False
           if outputs is None:
             raise ValueError('A layer\'s `call` method should return a Tensor '
                              'or a list of Tensors, not None (layer: ' +
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..1f5176c4d7 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -192,6 +192,20 @@ class CorrectnessTest(test.TestCase):
     history = model.fit(iterator, epochs=1, steps_per_epoch=10)
     self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
 
+  def test_no_loss_in_call(self):
+
+    class HasLoss(keras.layers.Layer):
+
+      def call(self, x):
+        self.add_loss(x)
+        return x
+
+    layer = HasLoss()
+    with self.assertRaises(RuntimeError):
+      layer(1.)
+
+    with ops.Graph().as_default():
+      layer(1.)
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..54ad74c08b 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util as tf_test_util
@@ -2427,6 +2428,17 @@ class TestTrainingWithMetrics(test.TestCase):
     scores = model.train_on_batch(x, y, sample_weight=w)
     self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
 
+  def test_losses_in_defun(self):
+    with context.eager_mode():
+      layer = keras.layers.Dense(1, kernel_regularizer='l1')
+      layer(array_ops.ones([1, 10]))
+
+      @function.defun
+      def get_losses():
+        return layer.losses
+
+      self.assertAllEqual(self.evaluate(layer.losses),
+                          self.evaluate(get_losses()))
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
 
   def add_loss(self, losses, inputs=None):
     previous_losses_length = len(self._losses)
+    previous_callable_losses_length = len(self._callable_losses)
     super(Layer, self).add_loss(losses, inputs=inputs)
-    # TODO(fchollet): deprecate collection below.
-    new_losses = self._losses[previous_losses_length:]
-    _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+    if not context.executing_eagerly():
+      # TODO(fchollet): deprecate collection below.
+      new_losses = self._losses[previous_losses_length:]
+      new_callable_losses = self._callable_losses[
+          previous_callable_losses_length:]
+      for regularizer in new_callable_losses:
+        loss_tensor = regularizer()
+        if loss_tensor is not None:
+          new_losses.append(loss_tensor)
+      _add_elements_to_collection(
+          new_losses,
+          ops.GraphKeys.REGULARIZATION_LOSSES)
 
   def _name_scope(self):
     """Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv2DBiasRegularizer(self):
     height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv2DNoBias(self):
     height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
     layer.apply(data)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv1DPointwiseRegularizer(self):
     length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
     layer.apply(data)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv1DBiasRegularizer(self):
     length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
     layer.apply(data)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv1DNoBias(self):
     length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv2DPointwiseRegularizer(self):
     height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv2DBiasRegularizer(self):
     height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testSeparableConv2DNoBias(self):
     height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv2DTransposeBiasRegularizer(self):
     height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
     layer.apply(images)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv2DTransposeNoBias(self):
     height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
     layer.apply(volumes)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv3DTransposeBiasRegularizer(self):
     depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
     layer.apply(volumes)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(layer.losses, loss_keys)
+    self.evaluate([v.initializer for v in layer.variables])
+    self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
 
   def testConv3DTransposeNoBias(self):
     depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
     _ = dense(inputs)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(dense.losses, loss_keys)
+    self.evaluate([v.initializer for v in dense.variables])
+    self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
 
   def testKernelRegularizerWithReuse(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
     _ = dense(inputs)
     loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(dense.losses, loss_keys)
+    self.evaluate([v.initializer for v in dense.variables])
+    self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
 
   def testFunctionalDense(self):
     with self.cached_session():
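
The new test cases above boil down to the following hedged sketch (same assumptions as the earlier example; `function.defun` is the internal tracing decorator that training_test.py imports):

```python
# Losses computed inside a defun now match the eagerly computed values,
# because the `losses` property evaluates the stored callables on access.
import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

layer = tf.keras.layers.Dense(1, kernel_regularizer='l1')
layer(tf.ones([1, 10]))

@function.defun
def get_losses():
  return layer.losses  # Evaluated at trace time, not returned as raw lambdas.

print([t.numpy() for t in layer.losses])   # Same values...
print([t.numpy() for t in get_losses()])   # ...inside the defun.

# Because regularizers run when `losses` is accessed, reading the property
# under a GradientTape propagates gradients back to the kernel variable.
with tf.GradientTape() as tape:
  regularization = tf.add_n(layer.losses)
print(tape.gradient(regularization, [layer.kernel]))
```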