aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar Nathan Silberman <nsilberman@google.com>2017-03-13 09:13:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-13 10:25:53 -0700
commit7160a7e030572664c7ded4e7a412601a3666bf2b (patch)
tree835da0122e99747ae2c98d1a66d7f2002906f150 /tensorflow/contrib/losses
parentc1af2f81dea1bd2ec9814f63cfdaf016872e7e29 (diff)
Adding test for associative property of pairwise squared loss.
Change: 149955752
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py51
1 files changed, 46 insertions, 5 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 81a4aaba2b..9d0f95e6f3 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -265,7 +266,8 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = [[2.3], [2.4], [2.5]]
- weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+ weights_placeholder = array_ops.placeholder(
+ dtypes.float32, shape=[None, None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
with self.test_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
@@ -479,8 +481,10 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
weights = [2.3, 2.4, 2.5]
- weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
- loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+ weights_placeholder = array_ops.placeholder(
+ dtypes.float32, shape=[None])
+ loss = loss_ops.sparse_softmax_cross_entropy(
+ logits, labels, weights_placeholder)
with self.test_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -491,8 +495,10 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
weights = [[2.3], [2.4], [2.5]]
- weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
- loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+ weights_placeholder = array_ops.placeholder(
+ dtypes.float32, shape=[None, None])
+ loss = loss_ops.sparse_softmax_cross_entropy(
+ logits, labels, weights_placeholder)
with self.test_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -1054,6 +1060,41 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
+ def testLossIsAssociativeAcrossBatchElements(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(0)
+
+ height = 3
+ width = 4
+ shape = (1, height, width, 1)
+
+ labels0 = random_ops.random_uniform(
+ shape, minval=0, maxval=1, dtype=dtypes.float32)
+ predictions0 = random_ops.random_uniform(
+ shape, minval=0, maxval=1, dtype=dtypes.float32)
+
+ labels1 = random_ops.random_uniform(
+ shape, minval=0, maxval=1, dtype=dtypes.float32)
+ predictions1 = random_ops.random_uniform(
+ shape, minval=0, maxval=1, dtype=dtypes.float32)
+
+ loss0 = loss_ops.mean_pairwise_squared_error(
+ predictions=predictions0,
+ labels=labels0)
+ loss1 = loss_ops.mean_pairwise_squared_error(
+ predictions=predictions1,
+ labels=labels1)
+ loss0_1 = loss_ops.mean_pairwise_squared_error(
+ predictions=array_ops.concat([predictions0, predictions1], 0),
+ labels=array_ops.concat([labels0, labels1], 0))
+
+ with self.test_session() as session:
+ loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
+
+ self.assertTrue(loss0 > 0)
+ self.assertTrue(loss1 > 0)
+ self.assertAlmostEqual(loss0 + loss1, loss0_1, 5)
+
class CosineDistanceLossTest(test.TestCase):