aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses/python/losses/loss_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/losses/python/losses/loss_ops.py')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 8c3a8afe7a..bdad34a665 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_argument_lookup
__all__ = [
"absolute_difference", "add_loss", "cosine_distance",
@@ -651,11 +652,9 @@ def cosine_distance(predictions,
ValueError: If `predictions` shape doesn't match `labels` shape, or
`weights` is `None`.
"""
- if dim is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dim'")
- axis = dim
- if axis is None and dim is None:
+ axis = deprecated_argument_lookup(
+ "axis", axis, "dim", dim)
+ if axis is None:
raise ValueError("You must specify 'axis'.")
with ops.name_scope(scope, "cosine_distance_loss",
[predictions, labels, weights]) as scope: