aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-28 09:27:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 09:33:20 -0700
commit4eb53d3e5f7bec3c757a06d186ff31fe52083e6d (patch)
treeb3844674c71f21e7a79ec014df9e395a80507400 /tensorflow/python/layers
parentf4014108a310928cd897085a8bc7d757c641a1c3 (diff)
Simplify eager/graph Layer.losses conditionals
Fixes an issue where losses created while executing eagerly were returned as unevaluated lambdas in a defun. Lazily evaluates Layer losses by default when possible. Even when graph building this is generally a better thing to do (e.g. losses called in a while_loop). Allows calls to Layer.add_loss when executing eagerly, but only for losses which are not conditional on inputs (no activity regularizers). PiperOrigin-RevId: 214947108
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py16
-rw-r--r--tensorflow/python/layers/convolutional_test.py36
-rw-r--r--tensorflow/python/layers/core_test.py6
3 files changed, 41 insertions, 17 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():