diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index f1453f9ef0..50673ed427 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -335,6 +335,16 @@ def _EluGradGrad(op, grad): dtype=elu_x.dtype))) +@ops.RegisterGradient("SeluGrad") +def _SeluGradGrad(op, grad): + x = op.inputs[1] + scale_alpha = 1.7580993408473768599402175208123 + return (gen_nn_ops._elu_grad(grad, op.outputs[0]), + array_ops.where( + x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))) + + @ops.RegisterGradient("Relu6") def _Relu6Grad(op, grad): return gen_nn_ops._relu6_grad(grad, op.inputs[0]) @@ -345,6 +355,11 @@ def _EluGrad(op, grad): return gen_nn_ops._elu_grad(grad, op.outputs[0]) +@ops.RegisterGradient("Selu") +def _SeluGrad(op, grad): + return gen_nn_ops._selu_grad(grad, op.outputs[0]) + + @ops.RegisterGradient("Softplus") def _SoftplusGrad(op, grad): return gen_nn_ops._softplus_grad(grad, op.inputs[0]) |