Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 11
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index f5e9550b97..9b765390b3 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -21,10 +21,10 @@ from __future__ import print_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import gen_nn_ops
 
 
 @ops.RegisterGradient("Conv2DBackpropInput")
@@ -330,7 +330,7 @@ def _EluGradGrad(op, grad):
   return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
           array_ops.where(
               x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1),
-              array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype)))
+              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
 
 
 @ops.RegisterGradient("Relu6")
@@ -387,12 +387,13 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   softmax_grad = op.outputs[1]
   grad = _BroadcastMul(grad_loss, softmax_grad)
-  if grad_grad.op.type not in ('ZerosLike', 'Zeros'):
+  if grad_grad.op.type not in ("ZerosLike", "Zeros"):
     logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
-    grad += ((grad_grad - array_ops.squeeze(math_ops.matmul(grad_grad[:, None, :],
-                                                            softmax[:, :, None]), axis=1)) * softmax)
+    grad += ((grad_grad - array_ops.squeeze(
+        math_ops.matmul(grad_grad[:, None, :],
+                        softmax[:, :, None]), axis=1)) * softmax)
   return grad, None
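Note on the _SoftmaxCrossEntropyWithLogitsGrad hunk: for loss = -sum(labels * log(softmax(z))), the Hessian with respect to the logits is diag(p) - p p^T with p = softmax(z), so a Hessian-vector product is p * (v - p.v). The batched matmul grad_grad[:, None, :] @ softmax[:, :, None] is just the per-row dot product p.v, and the hunk only rewraps that expression. A standalone NumPy sketch (assumptions and shapes mine, not from the patch) checking the two formulations agree:

import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 5))   # batch of logits
v = rng.normal(size=(4, 5))   # stand-in for grad_grad
p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # row-wise softmax

# Matmul form from the diff: (B,1,C) @ (B,C,1) -> (B,1,1), squeezed to (B,1).
dot_matmul = np.squeeze(v[:, None, :] @ p[:, :, None], axis=1)
# Equivalent explicit per-row dot product.
dot_sum = (v * p).sum(axis=1, keepdims=True)
assert np.allclose(dot_matmul, dot_sum)

# The extra gradient term: H v = p * (v - p.v),
# matching (grad_grad - squeeze(matmul(...))) * softmax above.
hvp = (v - dot_sum) * p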