aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--tensorflow/python/ops/nn_grad.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index e1a01ab4c3..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad):
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+@ops.RegisterGradient("LeakyRelu")
+def _LeakyReluGrad(op, grad):
+ x = op.inputs[0]
+ alpha = op.get_attr("alpha")
+ return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)
+
+
+@ops.RegisterGradient("LeakyReluGrad")
+def _LeakyReluGradGrad(op, grad):
+ x = op.inputs[1]
+ alpha = op.get_attr("alpha")
+ return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops.elu_grad(grad, op.outputs[0])