aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/core_test.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-28 09:27:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 09:33:20 -0700
commit4eb53d3e5f7bec3c757a06d186ff31fe52083e6d (patch)
treeb3844674c71f21e7a79ec014df9e395a80507400 /tensorflow/python/layers/core_test.py
parentf4014108a310928cd897085a8bc7d757c641a1c3 (diff)
Simplify eager/graph Layer.losses conditionals
Fixes an issue where losses created while executing eagerly were returned as unevaluated lambdas in a defun. Lazily evaluates Layer losses by default when possible. Even when graph building this is generally a better thing to do (e.g. losses called in a while_loop). Allows calls to Layer.add_loss when executing eagerly, but only for losses which are not conditional on inputs (no activity regularizers). PiperOrigin-RevId: 214947108
Diffstat (limited to 'tensorflow/python/layers/core_test.py')
-rw-r--r--tensorflow/python/layers/core_test.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():