Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 5 |
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 4b406ba840..557f39fb42 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -420,6 +420,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   # grad_loss is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
   # grad_grad is the backprop for softmax gradient.
+  # There is no gradient for the labels
   #
   # Second derivative is just softmax derivative w.r.t. logits.
   softmax_grad = op.outputs[1]
@@ -435,15 +436,15 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
     const_fill_value = tensor_util.constant_value(g)
     return const_fill_value is not None and (const_fill_value == 0).all()
 
-  logits = op.inputs[0]
   if grad_grad is not None and not IsZero(grad_grad):
+    logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
 
     grad += ((grad_grad - array_ops.squeeze(
         math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]),
         axis=1)) * softmax)
 
-  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
+  return grad, None
 
 
 @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
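
For reference, a minimal sketch of what the two gradient paths touched above compute. It is not part of this commit: it uses present-day TF 2.x eager APIs and made-up tensor values rather than this build's graph-mode code. The first-order gradient with respect to the logits is softmax(logits) - labels (grad_loss broadcast-multiplied with op.outputs[1]), and the grad_grad branch applies the softmax Jacobian, softmax * (v - <v, softmax>), to the incoming second-order gradient. With this change the gradient slot returned for the labels input is None.

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
labels = tf.constant([[1.0, 0.0, 0.0]])
v = tf.constant([[0.3, -0.2, 0.5]])  # arbitrary vector for the Hessian-vector product

# First-order gradient w.r.t. logits: softmax(logits) - labels,
# i.e. grad_loss broadcast-multiplied with op.outputs[1] (grad_loss == 1 here).
with tf.GradientTape() as tape:
  tape.watch(logits)
  loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
g = tape.gradient(loss, logits)
print(g.numpy())
print((tf.nn.softmax(logits) - labels).numpy())  # should match

# Second-order check: the Hessian-vector product equals the softmax Jacobian
# applied to v, i.e. softmax * (v - <v, softmax>), which is what the grad_grad
# branch in the second hunk computes.
with tf.GradientTape() as outer:
  outer.watch(logits)
  with tf.GradientTape() as inner:
    inner.watch(logits)
    loss2 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  g2 = inner.gradient(loss2, logits)      # softmax(logits) - labels
  vjp = tf.reduce_sum(g2 * v)
hvp = outer.gradient(vjp, logits)
sm = tf.nn.softmax(logits)
print(hvp.numpy())
print((sm * (v - tf.reduce_sum(v * sm, axis=-1, keepdims=True))).numpy())  # should match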