Diffstat (limited to 'tensorflow/core/kernels/relu_op.h')
-rw-r--r--  tensorflow/core/kernels/relu_op.h  61
1 file changed, 61 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
 }
 
 template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+  explicit LeakyReluOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+    float alpha_tmp;
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+    alpha_ = T(alpha_tmp);
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::LeakyRelu<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+            output->flat<T>());
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+    : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+  explicit LeakyReluGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+    float alpha_tmp;
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+    alpha_ = T(alpha_tmp);
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, T alpha, Tensor* output);
+
+  // INPUTS:
+  //   g (gradients): backpropagated gradients
+  //   a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+  //               outputs (using either one yields the same result here).
+  // OUTPUT:
+  //   gradients to backprop
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, alpha_, output);
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                                   const Tensor& g,
+                                                   const Tensor& a, T alpha,
+                                                   Tensor* output) {
+  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+  functor::LeakyReluGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+          output->flat<T>());
+}
+
+template <typename Device, typename T>
 class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
  public:
   using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
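
Note: the classes added above only handle the "alpha" attribute and shape validation; the element-wise math is delegated to functor::LeakyRelu and functor::LeakyReluGrad, which this header assumes are defined elsewhere (presumably alongside the other Relu functors in relu_op_functor.h, a file this diff does not touch). The following standalone sketch, illustrative C++ rather than TensorFlow code, shows the computation those functors are expected to perform:

// Standalone illustration (not TensorFlow code) of the element-wise math
// that functor::LeakyRelu and functor::LeakyReluGrad are expected to
// implement for a given alpha.
#include <cstdio>

// Forward pass: f(x) = x for x > 0, alpha * x otherwise.
float LeakyRelu(float x, float alpha) { return x > 0.0f ? x : alpha * x; }

// Backward pass: df/dx = 1 for x > 0 and alpha otherwise, so the incoming
// gradient g is either passed through or scaled by alpha. For alpha > 0
// the forward pass preserves sign, so a may hold either the op's input or
// its output and the test below picks the same branch.
float LeakyReluGrad(float g, float a, float alpha) {
  return a > 0.0f ? g : alpha * g;
}

int main() {
  const float alpha = 0.2f;  // tf.nn.leaky_relu's default alpha.
  const float xs[] = {-2.0f, -0.5f, 0.0f, 1.5f};
  for (float x : xs) {
    std::printf("x=%5.2f  f(x)=%5.2f  df/dx=%4.2f\n", x,
                LeakyRelu(x, alpha), LeakyReluGrad(1.0f, x, alpha));
  }
  return 0;
}

Because the forward pass preserves sign whenever alpha > 0, testing a > 0 selects the same branch whether the tensor a holds the op's input or its output, which is exactly what the comment on LeakyReluGradOp::Operate relies on.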