diff options
author | Nathan Silberman <nsilberman@google.com> | 2017-03-13 09:13:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-13 10:25:53 -0700 |
commit | 7160a7e030572664c7ded4e7a412601a3666bf2b (patch) | |
tree | 835da0122e99747ae2c98d1a66d7f2002906f150 /tensorflow/contrib/losses | |
parent | c1af2f81dea1bd2ec9814f63cfdaf016872e7e29 (diff) |
Adding test for associative property of pairwise squared loss.
Change: 149955752
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops_test.py | 51 |
1 files changed, 46 insertions, 5 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 81a4aaba2b..9d0f95e6f3 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -265,7 +266,8 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = [[2.3], [2.4], [2.5]] - weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None]) + weights_placeholder = array_ops.placeholder( + dtypes.float32, shape=[None, None]) loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder) with self.test_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) @@ -479,8 +481,10 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([2, 0, 1]) weights = [2.3, 2.4, 2.5] - weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None]) - loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder) + weights_placeholder = array_ops.placeholder( + dtypes.float32, shape=[None]) + loss = loss_ops.sparse_softmax_cross_entropy( + logits, labels, weights_placeholder) with self.test_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -491,8 +495,10 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([2, 0, 1]) weights = [[2.3], [2.4], [2.5]] - weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None]) - loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder) + weights_placeholder = array_ops.placeholder( + dtypes.float32, shape=[None, None]) + loss = loss_ops.sparse_softmax_cross_entropy( + logits, labels, weights_placeholder) with self.test_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -1054,6 +1060,41 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) + def testLossIsAssociativeAcrossBatchElements(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(0) + + height = 3 + width = 4 + shape = (1, height, width, 1) + + labels0 = random_ops.random_uniform( + shape, minval=0, maxval=1, dtype=dtypes.float32) + predictions0 = random_ops.random_uniform( + shape, minval=0, maxval=1, dtype=dtypes.float32) + + labels1 = random_ops.random_uniform( + shape, minval=0, maxval=1, dtype=dtypes.float32) + predictions1 = random_ops.random_uniform( + shape, minval=0, maxval=1, dtype=dtypes.float32) + + loss0 = loss_ops.mean_pairwise_squared_error( + predictions=predictions0, + labels=labels0) + loss1 = loss_ops.mean_pairwise_squared_error( + predictions=predictions1, + labels=labels1) + loss0_1 = loss_ops.mean_pairwise_squared_error( + predictions=array_ops.concat([predictions0, predictions1], 0), + labels=array_ops.concat([labels0, labels1], 0)) + + with self.test_session() as session: + loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1]) + + self.assertTrue(loss0 > 0) + self.assertTrue(loss1 > 0) + self.assertAlmostEqual(loss0 + loss1, loss0_1, 5) + class CosineDistanceLossTest(test.TestCase): |