diff options
author | 2016-11-11 15:24:06 -0800 | |
---|---|---|
committer | 2016-11-11 15:45:55 -0800 | |
commit | 84a4cbe5cd2a65cb60ccc65eac8c00caf4e98aed (patch) | |
tree | 99f8d8b5d57caee5a1cac1879036f5dbef9b65cf /tensorflow/contrib/losses | |
parent | 7e3608db285f5b92443fcaf77f233cda825ce3ae (diff) |
Change loss_ops_test to pass in data with batch_size 1.
Change: 138925660
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops_test.py | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index e08752c8d2..75785f00a1 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -1350,30 +1350,43 @@ class ComputeWeightedLossTest(tf.test.TestCase): class AddLossTest(tf.test.TestCase): def testAddExternalLoss(self): - logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + logits = tf.constant([[1.2, 0.4, -1.0, -1.1]]) + labels = tf.constant([[1.0, 0.0, 0.0, 1.0]]) losses = tf.contrib.losses.hinge_loss(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) tf.contrib.losses.add_loss(tf.reduce_mean(losses)) self.assertTrue(tf.contrib.losses.get_losses()) total_loss = tf.contrib.losses.get_total_loss() with self.test_session(): - self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3) + self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3) self.assertAllClose(total_loss.eval(), 3.5/4.0, atol=1e-3) def testNoneLossCollection(self): - logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + logits = tf.constant([[1.2, 0.4, -1.0, -1.1]]) + labels = tf.constant([[1.0, 0.0, 0.0, 1.0]]) losses = tf.contrib.losses.hinge_loss(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) tf.contrib.losses.add_loss(tf.reduce_mean(losses), loss_collection=None) self.assertFalse(tf.contrib.losses.get_losses()) with self.test_session(): - self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3) + self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3) def testNoCollectLosses(self): - logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + logits = tf.constant([[1.2, 0.4, -1.0, -1.1]]) + labels = tf.constant([[1.0, 0.0, 0.0, 1.0]]) + self.assertFalse(tf.contrib.losses.get_losses()) + with tf.contrib.framework.arg_scope([tf.contrib.losses.add_loss], + loss_collection=None): + tf.contrib.losses.absolute_difference(logits, labels) + tf.contrib.losses.log_loss(logits, labels) + tf.contrib.losses.mean_squared_error(logits, labels) + tf.contrib.losses.sigmoid_cross_entropy(logits, labels) + tf.contrib.losses.softmax_cross_entropy(logits, labels) + self.assertFalse(tf.contrib.losses.get_losses()) + + def testNoCollectLossesBatch2(self): + logits = tf.constant([[1.2, 0.4, -1.0, -1.1]] * 2) + labels = tf.constant([[1.0, 0.0, 0.0, 1.0]] * 2) self.assertFalse(tf.contrib.losses.get_losses()) with tf.contrib.framework.arg_scope([tf.contrib.losses.add_loss], loss_collection=None): |