author     Benoit Steiner <bsteiner@google.com>              2017-02-08 09:25:09 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-02-08 09:50:05 -0800
commit     639b4e71f532761a4840b1cdbaea55ad0917c75b (patch)
tree       5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/contrib/losses
parent     15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff)
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py        1
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops_test.py  52
2 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 1e4fb58945..5ca8c8a18b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -427,7 +427,6 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
   with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                       [logits, labels, weights]) as scope:
     labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
-    weights = array_ops.squeeze(weights)
     losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits,
                                                          name="xentropy")
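
Why dropping that squeeze helps with dynamically shaped weights (which is what the new tests below exercise): with no axis argument, array_ops.squeeze has to decide at graph-construction time which size-1 dimensions to drop, so weights fed through a placeholder with a partially unknown shape come out with unknown static rank, and the downstream weight broadcasting can no longer be checked. A minimal sketch of that shape-inference behavior (TF 1.x API; a sketch, not part of this commit):

    import tensorflow as tf

    weights = tf.placeholder(tf.float32, shape=[None])  # static rank is 1
    squeezed = tf.squeeze(weights)   # dim may or may not be 1 at runtime
    print(weights.get_shape())       # (?,)
    print(squeezed.get_shape())      # <unknown> -- static rank is lost
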
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 94b8dfca57..81a4aaba2b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -243,6 +243,34 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
       expected_value = 400.0 * label_smoothing / 3.0
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([[0, 0, 1],
+                                   [1, 0, 0],
+                                   [0, 1, 0]])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+    loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
 class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
@@ -445,6 +473,30 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
         loss_ops.sparse_softmax_cross_entropy(
             logits, labels, weights=weights).eval()
+  def testLossWithDynamicallyShapedWeights1D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [2.3, 2.4, 2.5]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+    loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+  def testLossWithDynamicallyShapedWeights2D(self):
+    logits = constant_op.constant([[10.0, 0.0, 0.0],
+                                   [0.0, 10.0, 0.0],
+                                   [0.0, 0.0, 10.0]])
+    labels = constant_op.constant([2, 0, 1])
+    weights = [[2.3], [2.4], [2.5]]
+    weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+    loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+    with self.test_session() as sess:
+      loss = sess.run(loss, {weights_placeholder: weights})
+      self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
 class SigmoidCrossEntropyLossTest(test.TestCase):
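
For reference, the expected value asserted by the new tests checks out numerically: in every test case the logits put a margin of 10.0 on a class other than the label, so each per-example cross-entropy is ~10.0, and the weighted loss reduces to np.average(weights) * 10.0. A quick NumPy check of that arithmetic (a sketch, not part of the commit; it assumes contrib's compute_weighted_loss divides the weighted sum by the number of present weights, here 3):

    import numpy as np

    logits = np.array([[10.0, 0.0, 0.0],
                       [0.0, 10.0, 0.0],
                       [0.0, 0.0, 10.0]])
    labels = np.array([2, 0, 1])             # sparse labels from the tests
    weights = np.array([2.3, 2.4, 2.5])

    # Per-example softmax cross-entropy: logsumexp(logits) - logit[label].
    lse = np.log(np.exp(logits).sum(axis=1))
    per_example = lse - logits[np.arange(3), labels]   # each ~= 10.0001

    loss = (weights * per_example).sum() / len(weights)
    print(loss, np.average(weights) * 10.0)            # both ~= 24.0
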