aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-11 15:24:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-11 15:45:55 -0800
commit84a4cbe5cd2a65cb60ccc65eac8c00caf4e98aed (patch)
tree99f8d8b5d57caee5a1cac1879036f5dbef9b65cf /tensorflow/contrib/losses
parent7e3608db285f5b92443fcaf77f233cda825ce3ae (diff)
Change loss_ops_test to pass in data with batch_size 1.
Change: 138925660
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index e08752c8d2..75785f00a1 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -1350,30 +1350,43 @@ class ComputeWeightedLossTest(tf.test.TestCase):
class AddLossTest(tf.test.TestCase):
def testAddExternalLoss(self):
- logits = tf.constant([1.2, 0.4, -1.0, -1.1])
- labels = tf.constant([1.0, 0.0, 0.0, 1.0])
+ logits = tf.constant([[1.2, 0.4, -1.0, -1.1]])
+ labels = tf.constant([[1.0, 0.0, 0.0, 1.0]])
losses = tf.contrib.losses.hinge_loss(logits, labels)
self.assertFalse(tf.contrib.losses.get_losses())
tf.contrib.losses.add_loss(tf.reduce_mean(losses))
self.assertTrue(tf.contrib.losses.get_losses())
total_loss = tf.contrib.losses.get_total_loss()
with self.test_session():
- self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
+ self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
self.assertAllClose(total_loss.eval(), 3.5/4.0, atol=1e-3)
def testNoneLossCollection(self):
- logits = tf.constant([1.2, 0.4, -1.0, -1.1])
- labels = tf.constant([1.0, 0.0, 0.0, 1.0])
+ logits = tf.constant([[1.2, 0.4, -1.0, -1.1]])
+ labels = tf.constant([[1.0, 0.0, 0.0, 1.0]])
losses = tf.contrib.losses.hinge_loss(logits, labels)
self.assertFalse(tf.contrib.losses.get_losses())
tf.contrib.losses.add_loss(tf.reduce_mean(losses), loss_collection=None)
self.assertFalse(tf.contrib.losses.get_losses())
with self.test_session():
- self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
+ self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
def testNoCollectLosses(self):
- logits = tf.constant([1.2, 0.4, -1.0, -1.1])
- labels = tf.constant([1.0, 0.0, 0.0, 1.0])
+ logits = tf.constant([[1.2, 0.4, -1.0, -1.1]])
+ labels = tf.constant([[1.0, 0.0, 0.0, 1.0]])
+ self.assertFalse(tf.contrib.losses.get_losses())
+ with tf.contrib.framework.arg_scope([tf.contrib.losses.add_loss],
+ loss_collection=None):
+ tf.contrib.losses.absolute_difference(logits, labels)
+ tf.contrib.losses.log_loss(logits, labels)
+ tf.contrib.losses.mean_squared_error(logits, labels)
+ tf.contrib.losses.sigmoid_cross_entropy(logits, labels)
+ tf.contrib.losses.softmax_cross_entropy(logits, labels)
+ self.assertFalse(tf.contrib.losses.get_losses())
+
+ def testNoCollectLossesBatch2(self):
+ logits = tf.constant([[1.2, 0.4, -1.0, -1.1]] * 2)
+ labels = tf.constant([[1.0, 0.0, 0.0, 1.0]] * 2)
self.assertFalse(tf.contrib.losses.get_losses())
with tf.contrib.framework.arg_scope([tf.contrib.losses.add_loss],
loss_collection=None):