aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 339c4e0e36..ded93d4a7f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=(1,))
+ labels = array_ops.expand_dims(labels, axis=(1,))
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, dim=(1,))
+ logits = array_ops.expand_dims(logits, axis=(1,))
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
return _compute_weighted_loss(loss, weights)
@@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=(1,))
+ labels = array_ops.expand_dims(labels, axis=(1,))
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, dim=(1,))
+ logits = array_ops.expand_dims(logits, axis=(1,))
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
name=name)
@@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None):
# TODO(ptucker): This will break for dynamic shapes.
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=(1,))
+ labels = array_ops.expand_dims(labels, axis=(1,))
loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
name=name)
return _compute_weighted_loss(loss, weights)