aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--tensorflow/python/ops/nn_grad.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index f1453f9ef0..50673ed427 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -335,6 +335,16 @@ def _EluGradGrad(op, grad):
dtype=elu_x.dtype)))
+@ops.RegisterGradient("SeluGrad")
+def _SeluGradGrad(op, grad):
+ x = op.inputs[1]
+ scale_alpha = 1.7580993408473768599402175208123
+ return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
+ array_ops.where(
+ x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+
+
@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
return gen_nn_ops._relu6_grad(grad, op.inputs[0])
@@ -345,6 +355,11 @@ def _EluGrad(op, grad):
return gen_nn_ops._elu_grad(grad, op.outputs[0])
+@ops.RegisterGradient("Selu")
+def _SeluGrad(op, grad):
+ return gen_nn_ops._selu_grad(grad, op.outputs[0])
+
+
@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
return gen_nn_ops._softplus_grad(grad, op.inputs[0])