aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/relu_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/relu_op.h')
-rw-r--r--tensorflow/core/kernels/relu_op.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+};
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;