author     Martin Wicke <wicke@google.com>                    2017-01-04 21:25:34 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-01-04 21:46:08 -0800
commit     333dc32ff79af21484695157f3d141dc776f7c02 (patch)
tree       b379bcaa56bfa54d12ea839fb7e62ab163490743 /tensorflow/contrib/losses
parent     d9541696b068cfcc1fab66b03d0b8d605b64f14d (diff)
Change arg order for {softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits to be (labels, predictions), and force use of named args to avoid accidents.
Change: 143629623
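
For illustration, a minimal sketch of the calling convention this commit enforces, assuming TF 1.x-era APIs; the tensor names and shapes below are hypothetical and not taken from the diff:

import tensorflow as tf

# Hypothetical batch of 32 examples over 10 classes, purely for illustration.
logits = tf.random_normal([32, 10])
onehot_labels = tf.one_hot(tf.random_uniform([32], maxval=10, dtype=tf.int32),
                           depth=10)

# Positional calls such as nn.softmax_cross_entropy_with_logits(logits, labels)
# are replaced by explicit keyword arguments, so accidentally swapping labels
# and logits can no longer silently produce a wrong loss.
losses = tf.nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
                                                 logits=logits,
                                                 name="xentropy")
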
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py  9
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index ed4469773b..69293bea13 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -340,7 +340,8 @@ def sigmoid_cross_entropy(
       multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
                             0.5 * label_smoothing)
-    losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
+    losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
+                                                  logits=logits,
                                                   name="xentropy")
     return compute_weighted_loss(losses, weights, scope=scope)
@@ -387,7 +388,8 @@ def softmax_cross_entropy(
       smooth_negatives = label_smoothing / num_classes
       onehot_labels = onehot_labels * smooth_positives + smooth_negatives
-    losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
+    losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
+                                                  logits=logits,
                                                   name="xentropy")
     return compute_weighted_loss(losses, weights, scope=scope)
@@ -421,7 +423,8 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
     labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
     weights = array_ops.squeeze(weights)
-    losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
+    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
+                                                          logits=logits,
                                                           name="xentropy")
     return compute_weighted_loss(losses, weights, scope=scope)
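
The hunks above only touch the internal nn.* calls; the contrib wrappers keep their existing signatures, so a call site like the following sketch (hypothetical tensors, assuming the module path shown in the diffstat) is unaffected by this commit:

import tensorflow as tf
from tensorflow.contrib.losses.python.losses import loss_ops

# Hypothetical inputs: 32 examples, 10 classes, integer class ids for the sparse op.
logits = tf.random_normal([32, 10])
labels = tf.random_uniform([32, 1], maxval=10, dtype=tf.int32)

# The wrapper's own (logits, labels, weights, scope) signature is unchanged; only
# its internal call to nn.sparse_softmax_cross_entropy_with_logits now uses
# named arguments.
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
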