diff options
Diffstat (limited to 'tensorflow/core/kernels/relu_op_functor.h')
-rw-r--r-- | tensorflow/core/kernels/relu_op_functor.h | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index e564da335a..f917142a12 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,6 +91,36 @@ struct Relu6Grad { } }; +// Functor used by LeakyReluOp to do the computations. +template <typename Device, typename T> +struct LeakyRelu { + // Computes LeakyRelu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes<T>::ConstTensor features, + T alpha, typename TTypes<T>::Tensor activations) { + activations.device(d) = features.cwiseMax(features * alpha); + } +}; + +// Functor used by LeakyReluGradOp to do the computations. +template <typename Device, typename T> +struct LeakyReluGrad { + // Computes LeakyReluGrad backprops. + // + // gradients: gradients backpropagated to the LeakyRelu op. + // features: either the inputs that were passed to the LeakyRelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the LeakyRelu inputs. + void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients, + typename TTypes<T>::ConstTensor features, T alpha, + typename TTypes<T>::Tensor backprops) { + backprops.device(d) = + (features > static_cast<T>(0)).select(gradients, gradients * alpha); + } +}; + // Functor used by EluOp to do the computations. template <typename Device, typename T> struct Elu { |