-rw-r--r--  tensorflow/python/kernel_tests/sparse_xent_op_test.py | 61
-rw-r--r--  tensorflow/python/ops/nn_ops.py                       | 63
2 files changed, 93 insertions(+), 31 deletions(-)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index a8050cb08d..eb6bdff8b5 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -30,6 +30,9 @@ from tensorflow.python.ops import sparse_ops
class SparseXentTest(tf.test.TestCase):
def _npXent(self, features, labels):
+ is_higher_dim = len(features.shape) > 2
+ features = np.reshape(features, [-1, features.shape[-1]])
+ labels = np.reshape(labels, [-1])
batch_dim = 0
class_dim = 1
batch_size = features.shape[batch_dim]
@@ -40,14 +43,15 @@ class SparseXentTest(tf.test.TestCase):
labels_mat[np.arange(batch_size), labels] = 1.0
bp = (probs - labels_mat)
l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
- return l, bp
+ return l, bp, is_higher_dim
def _testXent(self, np_features, np_labels, use_gpu=False):
- np_loss, np_backprop = self._npXent(np_features, np_labels)
+ np_loss, np_backprop, is_higher_dim = self._npXent(np_features, np_labels)
with self.test_session(use_gpu=use_gpu) as sess:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
np_features, np_labels)
- backprop = loss.op.outputs[1]
+ backprop = (loss.op.inputs[0].op.outputs[1] if is_higher_dim
+ else loss.op.outputs[1])
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -71,14 +75,6 @@ class SparseXentTest(tf.test.TestCase):
self._testSingleClass(use_gpu=True)
self._testSingleClass(use_gpu=False)
- def testRankTooLarge(self):
- np_features = np.array(
- [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32)
- np_labels = np.array([1, 2])
- self.assertRaisesRegexp(
- ValueError, "must have rank 2",
- tf.nn.sparse_softmax_cross_entropy_with_logits, np_features, np_labels)
-
def testNpXent(self):
# We create 2 batches of logits for testing.
# batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
@@ -104,7 +100,7 @@ class SparseXentTest(tf.test.TestCase):
# With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
# The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)]
# = [1.3862, 3.4420]
- np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+ np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
[-0.968, 0.087, 0.237, 0.6439]]),
np_backprop,
@@ -114,15 +110,21 @@ class SparseXentTest(tf.test.TestCase):
def testShapeMismatch(self):
with self.test_session():
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
tf.nn.sparse_softmax_cross_entropy_with_logits(
- [[0., 1.], [2., 3.]], [[0, 2]])
+ [[0., 1.], [2., 3.], [2., 3.]], [[0, 2]])
- def testNotMatrix(self):
+ def testScalar(self):
with self.test_session():
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
tf.nn.sparse_softmax_cross_entropy_with_logits(
- [0., 1., 2., 3.], [0, 2])
+ tf.constant(1.0), tf.constant(0))
+
+ def testVector(self):
+ with self.test_session():
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ tf.constant([1.0]), tf.constant(0))
+ self.assertAllClose(0.0, loss.eval())
def testFloat(self):
for label_dtype in np.int32, np.int64:
@@ -155,6 +157,31 @@ class SparseXentTest(tf.test.TestCase):
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
+ def _testHighDim(self, use_gpu, features, labels):
+ np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+ # manually reshape loss
+ np_loss = np.reshape(np_loss, np.array(labels).shape)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ features, labels)
+ backprop = loss.op.inputs[0].op.outputs[1]
+ tf_loss, tf_backprop = sess.run([loss, backprop])
+ self.assertAllCloseAccordingToType(np_loss, tf_loss)
+ self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+ def testHighDim(self):
+ features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]
+ labels = [[3], [0]]
+ self._testHighDim(True, features, labels)
+ self._testHighDim(False, features, labels)
+
+ def testHighDim2(self):
+ features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]],
+ [[1., 2., 3., 4.], [5., 6., 7., 8.]]]
+ labels = [[3, 2], [0, 3]]
+ self._testHighDim(True, features, labels)
+ self._testHighDim(False, features, labels)
+
def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
labels = tf.identity(labels)
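
As a usage reference for the higher-dimensional path exercised by the new testHighDim* cases above, here is a minimal sketch (not part of the diff; it assumes the 0.x-era positional-argument API used throughout this file):

    import tensorflow as tf

    # logits: [batch, time, num_classes]; labels: [batch, time] of class indices.
    logits = tf.constant([[[1., 1., 1., 1.], [2., 2., 2., 2.]],
                          [[1., 2., 3., 4.], [5., 6., 7., 8.]]])
    labels = tf.constant([[3, 2], [0, 3]])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

    with tf.Session() as sess:
      # The loss has the same shape as labels: [2, 2], one value per example.
      print(sess.run(loss))
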
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 8fb81a813a..baaa6391e9 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -440,30 +440,65 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
- `logits` must have the shape `[batch_size, num_classes]`
- and dtype `float32` or `float64`.
-
- `labels` must have the shape `[batch_size]` and dtype `int32` or `int64`.
+ A common use case is to have logits of shape `[batch_size, num_classes]` and
+ labels of shape `[batch_size]`. But higher dimensions are supported.
Args:
- logits: Unscaled log probabilities.
- labels: Each entry `labels[i]` must be an index in `[0, num_classes)`. Other
- values will result in a loss of 0, but incorrect gradient computations.
+ logits: Unscaled log probabilities of rank `r` and shape
+ `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
+ labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
+ `int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
+ Other values will result in a loss of 0, but incorrect gradient
+ computations.
name: A name for the operation (optional).
Returns:
- A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
- softmax cross entropy loss.
+ A `Tensor` of the same shape as `labels` and of the same type as `logits`
+ with the softmax cross entropy loss.
+
+ Raises:
+ ValueError: If logits are scalars (need to have rank >= 1) or if the rank
+ of the labels is not equal to the rank of the logits minus one.
"""
# TODO(pcmurray) Raise an error when the label is not an index in
# [0, num_classes). Note: This could break users who call this with bad
# labels, but disregard the bad results.
- # The second output tensor contains the gradients. We use it in
- # _CrossEntropyGrad() in nn_grad but not here.
- cost, unused_backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
- logits, labels, name=name)
- return cost
+ # Reshape logits and labels to rank 2.
+ with ops.op_scope([labels, logits], name,
+ "SparseSoftmaxCrossEntropyWithLogits"):
+ labels = ops.convert_to_tensor(labels)
+ logits = ops.convert_to_tensor(logits)
+
+ # Store label shape for result later.
+ labels_static_shape = labels.get_shape()
+ labels_shape = array_ops.shape(labels)
+ if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
+ raise ValueError("Logits cannot be scalars - received shape %s." %
+ logits.get_shape())
+ if logits.get_shape().ndims is not None and (
+ labels_static_shape.ndims is not None and
+ labels_static_shape.ndims != logits.get_shape().ndims - 1):
+ raise ValueError("Rank mismatch: Labels rank (received %s) should equal "
+ "logits rank (received %s) - 1." %
+ (labels_static_shape.ndims, logits.get_shape().ndims))
+ # Check if no reshapes are required.
+ if logits.get_shape().ndims == 2:
+ cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ return cost
+ # Reshape logits to 2 dim, labels to 1 dim.
+ num_classes = array_ops.gather(array_ops.shape(logits),
+ array_ops.rank(logits) - 1)
+ logits = array_ops.reshape(logits, [-1, num_classes])
+ labels = array_ops.reshape(labels, [-1])
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ cost = array_ops.reshape(cost, labels_shape)
+ cost.set_shape(labels_static_shape)
+ return cost
@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
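
For intuition, a small NumPy sketch (an illustration only, not code from the change) of the same flatten-compute-restore strategy the wrapper above applies, mirroring the _npXent test helper:

    import numpy as np

    def sparse_xent_nd(logits, labels):
      # Flatten every leading dimension; the last dimension is num_classes.
      label_shape = labels.shape
      logits2d = logits.reshape(-1, logits.shape[-1])
      labels1d = labels.reshape(-1)
      # Numerically stable softmax per row.
      shifted = logits2d - logits2d.max(axis=1, keepdims=True)
      probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
      # Cross entropy against the hard label, then restore the label shape.
      loss = -np.log(probs[np.arange(labels1d.size), labels1d] + 1.0e-20)
      return loss.reshape(label_shape)

    # Matches the [1.3862, 3.4420] losses worked out in the test comments above.
    print(sparse_xent_nd(np.array([[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]),
                         np.array([[3], [0]])))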