 tensorflow/python/kernel_tests/sparse_xent_op_test.py | 61
 tensorflow/python/ops/nn_ops.py                       | 63
 2 files changed, 93 insertions(+), 31 deletions(-)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index a8050cb08d..eb6bdff8b5 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -30,6 +30,9 @@ from tensorflow.python.ops import sparse_ops
 class SparseXentTest(tf.test.TestCase):
 
   def _npXent(self, features, labels):
+    is_higher_dim = len(features.shape) > 2
+    features = np.reshape(features, [-1, features.shape[-1]])
+    labels = np.reshape(labels, [-1])
     batch_dim = 0
     class_dim = 1
     batch_size = features.shape[batch_dim]
@@ -40,14 +43,15 @@ class SparseXentTest(tf.test.TestCase):
     labels_mat[np.arange(batch_size), labels] = 1.0
     bp = (probs - labels_mat)
     l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
-    return l, bp
+    return l, bp, is_higher_dim
 
   def _testXent(self, np_features, np_labels, use_gpu=False):
-    np_loss, np_backprop = self._npXent(np_features, np_labels)
+    np_loss, np_backprop, is_higher_dim = self._npXent(np_features, np_labels)
     with self.test_session(use_gpu=use_gpu) as sess:
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           np_features, np_labels)
-      backprop = loss.op.outputs[1]
+      backprop = (loss.op.inputs[0].op.outputs[1] if is_higher_dim
+                  else loss.op.outputs[1])
       tf_loss, tf_backprop = sess.run([loss, backprop])
       self.assertAllCloseAccordingToType(np_loss, tf_loss)
       self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -71,14 +75,6 @@ class SparseXentTest(tf.test.TestCase):
     self._testSingleClass(use_gpu=True)
     self._testSingleClass(use_gpu=False)
 
-  def testRankTooLarge(self):
-    np_features = np.array(
-        [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32)
-    np_labels = np.array([1, 2])
-    self.assertRaisesRegexp(
-        ValueError, "must have rank 2",
-        tf.nn.sparse_softmax_cross_entropy_with_logits, np_features, np_labels)
-
   def testNpXent(self):
     # We create 2 batches of logits for testing.
     # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
@@ -104,7 +100,7 @@ class SparseXentTest(tf.test.TestCase):
     # With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
     # The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)]
     # = [1.3862, 3.4420]
-    np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+    np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
     self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
                                   [-0.968, 0.087, 0.237, 0.6439]]),
                         np_backprop,
@@ -114,15 +110,21 @@
 
   def testShapeMismatch(self):
     with self.test_session():
-      with self.assertRaises(ValueError):
+      with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
         tf.nn.sparse_softmax_cross_entropy_with_logits(
-            [[0., 1.], [2., 3.]], [[0, 2]])
+            [[0., 1.], [2., 3.], [2., 3.]], [[0, 2]])
 
-  def testNotMatrix(self):
+  def testScalar(self):
     with self.test_session():
-      with self.assertRaises(ValueError):
+      with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
         tf.nn.sparse_softmax_cross_entropy_with_logits(
-            [0., 1., 2., 3.], [0, 2])
+            tf.constant(1.0), tf.constant(0))
+
+  def testVector(self):
+    with self.test_session():
+      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          tf.constant([1.0]), tf.constant(0))
+      self.assertAllClose(0.0, loss.eval())
 
   def testFloat(self):
     for label_dtype in np.int32, np.int64:
@@ -155,6 +157,31 @@ class SparseXentTest(tf.test.TestCase):
       print("cross entropy gradient err = ", err)
       self.assertLess(err, 5e-8)
 
+  def _testHighDim(self, use_gpu, features, labels):
+    np_loss, np_backprop, _ = self._npXent(np.array(features), np.array(labels))
+    # manually reshape loss
+    np_loss = np.reshape(np_loss, np.array(labels).shape)
+    with self.test_session(use_gpu=use_gpu) as sess:
+      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          features, labels)
+      backprop = loss.op.inputs[0].op.outputs[1]
+      tf_loss, tf_backprop = sess.run([loss, backprop])
+      self.assertAllCloseAccordingToType(np_loss, tf_loss)
+      self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
+
+  def testHighDim(self):
+    features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]
+    labels = [[3], [0]]
+    self._testHighDim(True, features, labels)
+    self._testHighDim(False, features, labels)
+
+  def testHighDim2(self):
+    features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]],
+                [[1., 2., 3., 4.], [5., 6., 7., 8.]]]
+    labels = [[3, 2], [0, 3]]
+    self._testHighDim(True, features, labels)
+    self._testHighDim(False, features, labels)
+
 
 def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
   labels = tf.identity(labels)
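For reference, the test's _npXent helper computes the expected values in plain NumPy: it flattens the logits to 2-D, softmaxes each row, takes the negative log-probability of the target class as the loss, and uses softmax(logits) - one_hot(labels) as the gradient. Below is a minimal standalone sketch of that same computation; the name np_sparse_xent, the row-max shift for numerical stability, and the final reshape of the loss back to the label shape are illustrative additions, not part of the patch.

import numpy as np

def np_sparse_xent(features, labels):
  # Flatten [d_0, ..., d_{r-2}, num_classes] logits to 2-D, as _npXent does.
  features = np.asarray(features, dtype=np.float64)
  labels = np.asarray(labels)
  label_shape = labels.shape
  features = features.reshape(-1, features.shape[-1])
  flat_labels = labels.reshape(-1)
  # Row-wise softmax, shifted by the row max for numerical stability.
  e = np.exp(features - features.max(axis=1, keepdims=True))
  probs = e / e.sum(axis=1, keepdims=True)
  # Gradient w.r.t. the logits is softmax(logits) - one_hot(labels).
  one_hot = np.zeros_like(probs)
  one_hot[np.arange(flat_labels.size), flat_labels] = 1.0
  backprop = probs - one_hot
  # Loss per row is the negative log-probability of the target class.
  loss = -np.log(probs[np.arange(flat_labels.size), flat_labels] + 1.0e-20)
  return loss.reshape(label_shape), backprop

loss, _ = np_sparse_xent([[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]], [[3], [0]])
print(loss)  # approximately [[1.3863], [3.4420]], matching testHighDim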
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 8fb81a813a..baaa6391e9 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -440,30 +440,65 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
   on `logits` internally for efficiency.  Do not call this op with the
   output of `softmax`, as it will produce incorrect results.
 
-  `logits` must have the shape `[batch_size, num_classes]`
-  and dtype `float32` or `float64`.
-
-  `labels` must have the shape `[batch_size]` and dtype `int32` or `int64`.
+  A common use case is to have logits of shape `[batch_size, num_classes]` and
+  labels of shape `[batch_size]`, but higher dimensions are supported.
 
   Args:
-    logits: Unscaled log probabilities.
-    labels: Each entry `labels[i]` must be an index in `[0, num_classes)`. Other
-      values will result in a loss of 0, but incorrect gradient computations.
+    logits: Unscaled log probabilities of rank `r` and shape
+      `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
+    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
+      `int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
+      Other values will result in a loss of 0, but incorrect gradient
+      computations.
     name: A name for the operation (optional).
 
   Returns:
-    A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
-    softmax cross entropy loss.
+    A `Tensor` of the same shape as `labels` and of the same type as `logits`
+    with the softmax cross entropy loss.
+
+  Raises:
+    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
+      of the labels is not equal to the rank of the logits minus one.
   """
   # TODO(pcmurray) Raise an error when the label is not an index in
   # [0, num_classes). Note: This could break users who call this with bad
   # labels, but disregard the bad results.
-  # The second output tensor contains the gradients. We use it in
-  # _CrossEntropyGrad() in nn_grad but not here.
-  cost, unused_backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
-      logits, labels, name=name)
-  return cost
+  # Reshape logits and labels to rank 2.
+  with ops.op_scope([labels, logits], name,
+                    "SparseSoftmaxCrossEntropyWithLogits"):
+    labels = ops.convert_to_tensor(labels)
+    logits = ops.convert_to_tensor(logits)
+
+    # Store label shape for result later.
+    labels_static_shape = labels.get_shape()
+    labels_shape = array_ops.shape(labels)
+    if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
+      raise ValueError("Logits cannot be scalars - received shape %s." %
+                       logits.get_shape())
+    if logits.get_shape().ndims is not None and (
+        labels_static_shape.ndims is not None and
+        labels_static_shape.ndims != logits.get_shape().ndims - 1):
+      raise ValueError("Rank mismatch: Labels rank (received %s) should equal "
+                       "logits rank (received %s) - 1." %
+                       (labels_static_shape.ndims, logits.get_shape().ndims))
+    # Check if no reshapes are required.
+    if logits.get_shape().ndims == 2:
+      cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+          logits, labels, name=name)
+      return cost
+    # Reshape logits to 2 dim, labels to 1 dim.
+    num_classes = array_ops.gather(array_ops.shape(logits),
+                                   array_ops.rank(logits) - 1)
+    logits = array_ops.reshape(logits, [-1, num_classes])
+    labels = array_ops.reshape(labels, [-1])
+    # The second output tensor contains the gradients. We use it in
+    # _CrossEntropyGrad() in nn_grad but not here.
+    cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+        logits, labels, name=name)
+    cost = array_ops.reshape(cost, labels_shape)
+    cost.set_shape(labels_static_shape)
+    return cost
 
 
 @ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
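The net effect of the nn_ops.py change: the kernel still only handles rank-2 logits, but the Python wrapper now flattens higher-rank inputs, runs the kernel, and reshapes the loss back to the label shape. A usage sketch of what this enables, assuming the TF 0.x-era positional API shown in this patch (the values reuse the testHighDim2 fixtures; previously this call was rejected with "must have rank 2"):

import tensorflow as tf

# 3-D logits [batch, time, num_classes] with 2-D labels [batch, time] now
# work directly; the wrapper's flatten -> kernel -> reshape round trip is
# invisible to the caller.
logits = tf.constant([[[1., 1., 1., 1.], [2., 2., 2., 2.]],
                      [[1., 2., 3., 4.], [5., 6., 7., 8.]]])
labels = tf.constant([[3, 2], [0, 3]])

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

with tf.Session() as sess:
  # The loss comes back with the same shape as labels: [2, 2].
  print(sess.run(loss))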