diff options
Diffstat (limited to 'tensorflow/core/kernels/relu_op.h')
-rw-r--r-- | tensorflow/core/kernels/relu_op.h | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index 365c6201a5..e712b02bd7 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -173,6 +173,48 @@ void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, output->flat<T>()); } +template <typename Device, typename T> +class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> { + public: + using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Selu<Device, T> functor; + functor(context->eigen_device<Device>(), input.flat<T>(), + output->flat<T>()); + } +}; + +template <typename Device, typename T> +class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> { + public: + using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (outputs): outputs of the SeluOp() + // OUTPUT: + // gradients to backprop + template <int NDIMS> + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template <typename Device, typename T> +void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::SeluGrad<Device, T> functor; + functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), + output->flat<T>()); +} + } // namespace tensorflow #undef EIGEN_USE_THREADS |