about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/relu_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/relu_op.h')
-rw-r--r--  tensorflow/core/kernels/relu_op.h  42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 365c6201a5..e712b02bd7 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
output->flat<T>());
}
+template <typename Device, typename T>
+class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {  // Forward SELU (scaled exponential linear unit) kernel.
+ public:
+  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;  // Inherit base-class constructors unchanged.
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {  // CRTP hook invoked by UnaryElementWiseOp; output is preallocated by the base — TODO confirm against framework.
+    functor::Selu<Device, T> functor;  // Device/type-specialized elementwise functor.
+    functor(context->eigen_device<Device>(), input.flat<T>(),
+            output->flat<T>());  // Apply SELU over the flattened tensor on the chosen device.
+  }
+};
+
+template <typename Device, typename T>
+class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {  // Backward (gradient) kernel for SELU.
+ public:
+  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;  // Inherit base-class constructors unchanged.
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);  // Non-template worker (defined out of line below) so Operate<NDIMS> stays a thin forwarder.
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (outputs): outputs of the SeluOp()
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);  // Rank-independent: the elementwise grad ignores NDIMS.
+  }
+};
+
+template <typename Device, typename T>
+void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              Tensor* output) {
+  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;  // Bail out on g/a size mismatch; presumably records the error on context — same pattern as EluGradOp above.
+  functor::SeluGrad<Device, T> functor;  // Device/type-specialized gradient functor.
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          output->flat<T>());  // Elementwise: output grad computed from incoming grads g and SELU outputs a.
+}
+
} // namespace tensorflow
#undef EIGEN_USE_THREADS