Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--  tensorflow/python/ops/nn_ops.py  57
1 file changed, 57 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bdc32360a3..6c0cd339e6 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -157,6 +157,12 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.
+ **NOTE:** While the classes are mutually exclusive, their probabilities
+ need not be. All that is required is that each row of `labels` is
+ a valid probability distribution. If using exclusive `labels`
+ (wherein one and only one class is true at a time), see
+ `sparse_softmax_cross_entropy_with_logits`.
+
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
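A minimal usage sketch of the op documented above (assuming the `tf.nn` export and the graph/`Session` workflow of this era; the values are made up for illustration). Each row of `labels` is a full probability distribution, which need not be one-hot, and the logits are passed unscaled:

import tensorflow as tf

# Three examples, four classes. Each row of `labels` sums to 1,
# but the rows are not required to be one-hot.
logits = tf.constant([[2.0, 1.0, 0.1, 0.0],
                      [0.5, 2.5, 0.3, 0.2],
                      [1.0, 1.0, 1.0, 1.0]])
labels = tf.constant([[1.0, 0.0, 0.0, 0.0],
                      [0.0, 0.7, 0.3, 0.0],
                      [0.25, 0.25, 0.25, 0.25]])

# Pass raw logits; the op applies softmax internally.
loss = tf.nn.softmax_cross_entropy_with_logits(logits, labels)

with tf.Session() as sess:
    print(sess.run(loss))  # shape [3], one loss value per example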
@@ -180,6 +186,57 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
return cost
+def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
+ """Computes sparse softmax cross entropy between `logits` and `labels`.
+
+ Measures the probability error in discrete classification tasks in which the
+ classes are mutually exclusive (each entry is in exactly one class). For
+ example, each CIFAR-10 image is labeled with one and only one label: an image
+ can be a dog or a truck, but not both.
+
+ **NOTE:** For this operation, the probability of a given label is considered
+ exclusive. That is, soft classes are not allowed, and the `labels` vector
+ must provide a single specific index for the true class for each row of
+ `logits` (each minibatch entry). For soft softmax classification with
+ a probability distribution for each entry, see
+ `softmax_cross_entropy_with_logits`.
+
+ **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+ on `logits` internally for efficiency. Do not call this op with the
+ output of `softmax`, as it will produce incorrect results.
+
+ `logits` must have the shape `[batch_size, num_classes]`
+ and the dtype (either `float32` or `float64`).
+
+ `labels` must have the shape `[batch_size]` and the dtype `int64`.
+
+ Args:
+ logits: Unscaled log probabilities.
+ labels: Each entry `labels[i]` must be an index in `[0, num_classes)`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+ """
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ return cost
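A minimal usage sketch of the new op (assuming it is exported as `tf.nn.sparse_softmax_cross_entropy_with_logits` with the signature added above; the values are made up for illustration). `labels` carries one `int64` class index per row of `logits`, and the logits are again passed unscaled:

import tensorflow as tf

# Three examples, four classes; exactly one true class index per example.
logits = tf.constant([[2.0, 1.0, 0.1, 0.0],
                      [0.5, 2.5, 0.3, 0.2],
                      [1.0, 1.0, 1.0, 1.0]])
labels = tf.constant([0, 1, 3], dtype=tf.int64)  # shape [batch_size]

# Pass raw logits; the op applies softmax internally.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

with tf.Session() as sess:
    print(sess.run(loss))  # shape [3], one loss value per example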
+
+
+@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
+def _SparseSoftmaxCrossEntropyWithLogitsShape(op):
+ """Shape function for SparseSoftmaxCrossEntropyWithLogits op."""
+ logits_shape = op.inputs[0].get_shape()
+ input_shape = logits_shape.with_rank(2)
+ batch_size = input_shape[0]
+ # labels must be a vector whose length matches the logits batch size.
+ op.inputs[1].get_shape().merge_with(tensor_shape.vector(batch_size))
+ return [tensor_shape.vector(batch_size.value), input_shape]
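To make the shape bookkeeping concrete, a small static-shape sketch (hypothetical sizes, again assuming the `tf.nn` export): from rank-2 logits and vector labels, the shape function above infers a `[batch_size]` loss (the first output) and a `[batch_size, num_classes]` backprop tensor (the second output consumed by `_CrossEntropyGrad`):

import tensorflow as tf

logits = tf.zeros([32, 10])              # [batch_size=32, num_classes=10]
labels = tf.zeros([32], dtype=tf.int64)  # [batch_size=32]

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
print(loss.get_shape())  # (32,) -- per-example loss, inferred without running the graph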
+
+
@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsShape(op):
"""Shape function for SoftmaxCrossEntropyWithLogits op."""