diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 17:43:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 17:43:54 -0700 |
commit | 90ec1ab915907ae94ac1d212611b460a68d8c98f (patch) | |
tree | fcdf75c1a31daee12f4ae6d5f1e8eedf5a1fef28 /tensorflow/contrib/learn | |
parent | 37bbf89920f013ef1d59f0eaef65431d4f4a4a28 (diff) | |
parent | a77a9689198675f62ced41eb5c737eec429b8fae (diff) |
Merge pull request #20520 from yongtang:07012018-expand_dims
PiperOrigin-RevId: 205331312
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 10 |
1 file changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 339c4e0e36..ded93d4a7f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None):
   labels = ops.convert_to_tensor(labels)
   # To prevent broadcasting inside "-".
   if len(labels.get_shape()) == 1:
-    labels = array_ops.expand_dims(labels, dim=(1,))
+    labels = array_ops.expand_dims(labels, axis=(1,))
   # TODO(zakaria): make sure it does not recreate the broadcast bug.
   if len(logits.get_shape()) == 1:
-    logits = array_ops.expand_dims(logits, dim=(1,))
+    logits = array_ops.expand_dims(logits, axis=(1,))
   logits.get_shape().assert_is_compatible_with(labels.get_shape())
   loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
   return _compute_weighted_loss(loss, weights)
@@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None):
   labels = ops.convert_to_tensor(labels)
   # To prevent broadcasting inside "-".
   if len(labels.get_shape()) == 1:
-    labels = array_ops.expand_dims(labels, dim=(1,))
+    labels = array_ops.expand_dims(labels, axis=(1,))
   # TODO(zakaria): make sure it does not recreate the broadcast bug.
   if len(logits.get_shape()) == 1:
-    logits = array_ops.expand_dims(logits, dim=(1,))
+    logits = array_ops.expand_dims(logits, axis=(1,))
   logits.get_shape().assert_is_compatible_with(labels.get_shape())
   loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
                              name=name)
@@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None):
   # TODO(ptucker): This will break for dynamic shapes.
   # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
   if len(labels.get_shape()) == 1:
-    labels = array_ops.expand_dims(labels, dim=(1,))
+    labels = array_ops.expand_dims(labels, axis=(1,))
   loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
                                               name=name)
   return _compute_weighted_loss(loss, weights)