about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-07-03 12:33:58 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-07-03 12:34:14 +0000
commita77a9689198675f62ced41eb5c737eec429b8fae (patch)
tree38bb5b015c18d5d115eb4c9e65a369f23a142596 /tensorflow/contrib/learn
parent00071753077dcd9f1486c1335f05eed80e68efcb (diff)
Fix warning in _log_loss_with_two_classes as well
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py2
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index e9c79f88b0..ded93d4a7f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None):
# TODO(ptucker): This will break for dynamic shapes.
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=(1,))
+ labels = array_ops.expand_dims(labels, axis=(1,))
loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
name=name)
return _compute_weighted_loss(loss, weights)