diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index f5e9550b97..9b765390b3 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -21,10 +21,10 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import gen_nn_ops @ops.RegisterGradient("Conv2DBackpropInput") @@ -330,7 +330,7 @@ def _EluGradGrad(op, grad): return (gen_nn_ops._elu_grad(grad, op.outputs[0]), array_ops.where( x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1), - array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype))) + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))) @ops.RegisterGradient("Relu6") @@ -387,12 +387,13 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad): softmax_grad = op.outputs[1] grad = _BroadcastMul(grad_loss, softmax_grad) - if grad_grad.op.type not in ('ZerosLike', 'Zeros'): + if grad_grad.op.type not in ("ZerosLike", "Zeros"): logits = op.inputs[0] softmax = nn_ops.softmax(logits) - grad += ((grad_grad - array_ops.squeeze(math_ops.matmul(grad_grad[:, None, :], - softmax[:, :, None]), axis=1)) * softmax) + grad += ((grad_grad - array_ops.squeeze( + math_ops.matmul(grad_grad[:, None, :], + softmax[:, :, None]), axis=1)) * softmax) return grad, None |