aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/relu_op_functor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/relu_op_functor.h')
-rw-r--r--tensorflow/core/kernels/relu_op_functor.h40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 633522920c..9577b963c6 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -125,6 +125,46 @@ struct EluGrad {
}
};
+// Functor used by SeluOp to do the computations.
+template <typename Device, typename T>
+struct Selu {
+ // Computes Selu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ // features.constant(?)
+ const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+ const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+ const auto one = static_cast<T>(1);
+ const auto zero = static_cast<T>(0);
+ activations.device(d) =
+ (features < zero)
+ .select(scale_alpha * (features.exp() - features.constant(one)),
+ scale * features);
+ }
+};
+
+// Functor used by SeluGradOp to do the computations.
+template <typename Device, typename T>
+struct SeluGrad {
+ // Computes SeluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Selu op.
+ // activations: outputs of the Selu op.
+ // backprops: gradients to backpropagate to the Selu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor activations,
+ typename TTypes<T>::Tensor backprops) {
+ const auto scale = static_cast<T>(1.0507009873554804934193349852946);
+ const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
+ backprops.device(d) =
+ (activations < static_cast<T>(0)).select(
+ gradients * (activations + scale_alpha), gradients * scale);
+ }
+};
+
} // namespace functor
} // namespace tensorflow