diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index e1a01ab4c3..902653befc 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad): array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) +@ops.RegisterGradient("LeakyRelu") +def _LeakyReluGrad(op, grad): + x = op.inputs[0] + alpha = op.get_attr("alpha") + return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha) + + +@ops.RegisterGradient("LeakyReluGrad") +def _LeakyReluGradGrad(op, grad): + x = op.inputs[1] + alpha = op.get_attr("alpha") + return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + + @ops.RegisterGradient("Elu") def _EluGrad(op, grad): return gen_nn_ops.elu_grad(grad, op.outputs[0]) |