diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-07-03 12:29:18 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-07-03 12:29:18 +0000 |
commit | 69e37cef0ca721f76d12a3808521d73299aab7ea (patch) | |
tree | f83ef422ad0221fe3ba0e2ff6cdbdbe269caf647 /tensorflow/contrib/learn | |
parent | 519487fb313a31701fd31e67c0ceb6eae8ea9225 (diff) |
Update calling of expand_dims with axis
This fix updates calling of `expand_dims` with `dim -> axis`
as the `dim=` in `tf.expand_dims` has been deprecated and was
generating unnecessary warnings.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 339c4e0e36..dee0755204 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=(1,)) + logits = array_ops.expand_dims(logits, axis=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = math_ops.square(logits - math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) |