aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/losses/losses_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/losses/losses_impl.py')
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 5222333d7e..ca408988dd 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -726,9 +726,11 @@ def softmax_cross_entropy(
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
- losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
- logits=logits,
- name="xentropy")
+ onehot_labels = array_ops.stop_gradient(
+ onehot_labels, name="labels_stop_gradient")
+ losses = nn.softmax_cross_entropy_with_logits_v2(
+ labels=onehot_labels, logits=logits, name="xentropy")
+
return compute_weighted_loss(
losses, weights, scope, loss_collection, reduction=reduction)