diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 452f8a901e..15e457f932 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -921,12 +921,21 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None): if not labels.dtype.is_integer: raise ValueError("Labels dtype should be integer " "Instead got %s." % labels.dtype) - # TODO(ptucker): This will break for dynamic shapes. + # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. + is_squeezed_labels = False + # TODO(ptucker): This will break for dynamic shapes. if len(labels.get_shape()) == 2: labels = array_ops.squeeze(labels, squeeze_dims=(1,)) + is_squeezed_labels = True + loss = nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name=name) + + # Restore squeezed dimension, if necessary, so loss matches weights shape. + if is_squeezed_labels: + loss = array_ops.expand_dims(loss, axis=(1,)) + return _compute_weighted_loss(loss, weights) |