Diffstat (limited to 'tensorflow/core/kernels/relu_op_functor.h')
-rw-r--r--  tensorflow/core/kernels/relu_op_functor.h | 40 ++++++++++
1 file changed, 40 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 633522920c..9577b963c6 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -125,6 +125,46 @@ struct EluGrad {
   }
 };
 
+// Functor used by SeluOp to do the computations.
+template <typename Device, typename T>
+struct Selu {
+  // Computes Selu activation.
+  //
+  // features: any shape.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  typename TTypes<T>::Tensor activations) {
+    // features.constant(?)
+    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+    const auto one = static_cast<T>(1);
+    const auto zero = static_cast<T>(0);
+    activations.device(d) =
+        (features < zero)
+            .select(scale_alpha * (features.exp() - features.constant(one)),
+                    scale * features);
+  }
+};
+
+// Functor used by SeluGradOp to do the computations.
+template <typename Device, typename T>
+struct SeluGrad {
+  // Computes SeluGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Selu op.
+  // activations: outputs of the Selu op.
+  // backprops: gradients to backpropagate to the Selu inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor activations,
+                  typename TTypes<T>::Tensor backprops) {
+    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+    backprops.device(d) =
+        (activations < static_cast<T>(0)).select(
+            gradients * (activations + scale_alpha), gradients * scale);
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
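
For reference, a minimal standalone sketch (not part of the commit) of the scalar math the two functors vectorize over Eigen tensors. The constants come straight from the diff (scale is the SELU lambda, scale_alpha is lambda * alpha); the helper names selu and selu_grad_from_activation are hypothetical, chosen for this example only.

// Scalar sketch of the SELU forward and backward math implemented by the
// Selu and SeluGrad functors above. Helper names are illustrative, not TF API.
#include <cassert>
#include <cmath>
#include <cstdio>

// Constants from the diff: scale = lambda, scale_alpha = lambda * alpha.
static const double kScale = 1.0507009873554804934193349852946;
static const double kScaleAlpha = 1.7580993408473768599402175208123;

// Forward pass: selu(x) = scale * x                  for x >= 0
//               selu(x) = scale_alpha * (e^x - 1)    for x <  0
double selu(double x) {
  return x < 0.0 ? kScaleAlpha * (std::exp(x) - 1.0) : kScale * x;
}

// Backward pass written in terms of the saved activation y = selu(x),
// exactly as SeluGrad does: for x < 0,
//   d/dx selu(x) = scale_alpha * e^x = y + scale_alpha.
double selu_grad_from_activation(double grad, double y) {
  return y < 0.0 ? grad * (y + kScaleAlpha) : grad * kScale;
}

int main() {
  const double x = -0.5;
  const double y = selu(x);
  // Analytic derivative for x < 0: scale_alpha * e^x.
  const double expected = kScaleAlpha * std::exp(x);
  const double got = selu_grad_from_activation(1.0, y);
  assert(std::abs(got - expected) < 1e-12);
  std::printf("selu(%.2f) = %.6f, d/dx = %.6f\n", x, y, got);
  return 0;
}

Note the design choice the sketch makes explicit: because the gradient for negative inputs equals activation + scale_alpha, SeluGrad can be computed from the op's saved outputs alone, with no second exp evaluation and no need to keep the original inputs for the backward pass.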