Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
 tensorflow/contrib/learn/python/learn/estimators/head.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 452f8a901e..15e457f932 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -921,12 +921,21 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
     if not labels.dtype.is_integer:
       raise ValueError("Labels dtype should be integer "
                        "Instead got %s." % labels.dtype)
-    # TODO(ptucker): This will break for dynamic shapes.
+
     # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
+    is_squeezed_labels = False
+    # TODO(ptucker): This will break for dynamic shapes.
     if len(labels.get_shape()) == 2:
       labels = array_ops.squeeze(labels, squeeze_dims=(1,))
+      is_squeezed_labels = True
+
     loss = nn.sparse_softmax_cross_entropy_with_logits(
         labels=labels, logits=logits, name=name)
+
+    # Restore squeezed dimension, if necessary, so loss matches weights shape.
+    if is_squeezed_labels:
+      loss = array_ops.expand_dims(loss, axis=(1,))
+
     return _compute_weighted_loss(loss, weights)
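
Why restoring the squeezed dimension matters: nn.sparse_softmax_cross_entropy_with_logits
returns a loss of shape [batch_size], while the weights the head receives are typically
[batch_size, 1], matching the rank-2 labels. Multiplying a [batch_size] loss by
[batch_size, 1] weights broadcasts to [batch_size, batch_size], silently corrupting the
weighted loss that _compute_weighted_loss produces. A minimal sketch of the shape
mismatch, assuming TF 1.x graph mode; the tensor values are illustrative only and are
not taken from the change:

  import tensorflow as tf

  labels = tf.constant([[0], [1]])                # [batch_size, 1], as the head receives them
  logits = tf.constant([[2.0, 1.0], [0.5, 3.0]])  # [batch_size, num_classes]
  weights = tf.constant([[1.0], [0.5]])           # [batch_size, 1], same rank as labels

  # The loss op requires rank-1 labels, hence the squeeze in the patch.
  squeezed = tf.squeeze(labels, axis=(1,))        # [batch_size]
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=squeezed, logits=logits)             # [batch_size]

  # Before the patch: [batch_size] * [batch_size, 1] broadcasts to
  # [batch_size, batch_size] -- a silently wrong weighted loss.
  broken = loss * weights

  # After the patch: the squeezed dimension is restored first, so the
  # multiplication is elementwise as intended.
  fixed = tf.expand_dims(loss, axis=1) * weights  # [batch_size, 1]

  with tf.Session() as sess:
    print(sess.run(broken).shape)  # (2, 2) -- broadcast bug
    print(sess.run(fixed).shape)   # (2, 1) -- matches weights

Note that in head.py the weighting is done by _compute_weighted_loss; the sketch inlines
a plain multiplication only to make the broadcast visible.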