about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/relu_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/relu_op.cc')
-rw-r--r-- tensorflow/core/kernels/relu_op.cc 62
1 files changed, 50 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index d8d30e87e2..afad288cc0 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -50,15 +50,21 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
-#define REGISTER_ELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- EluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- EluGradOp<CPUDevice, type>)
-
-// Elu only makes sense with float or double.
+// Registers the CPU kernels for the "Elu", "EluGrad", "Selu" and "SeluGrad"
+// ops for a single element type. `type` is bound to each op's "T" type
+// constraint and is also passed as the second template argument of the
+// corresponding kernel class (EluOp, EluGradOp, SeluOp, SeluGradOp).
+#define REGISTER_ELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ EluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ EluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SeluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SeluGradOp<CPUDevice, type>)
+
+// Elu and Selu only make sense with floating-point types.
+// NOTE(review): the type list below comes from TF_CALL_GPU_NUMBER_TYPES even
+// though these are CPU registrations. If that list also includes half (the
+// SYCL section uses a separate *_NO_HALF variant), then "float or double"
+// understates what is registered here -- confirm against the macro definition.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
@@ -103,7 +109,23 @@ namespace functor {
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor activations, \
typename TTypes<T>::Tensor backprops); \
- extern template struct EluGrad<GPUDevice, T>;
+ extern template struct EluGrad<GPUDevice, T>; \
+ \
+ template <> \
+ void Selu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Selu<GPUDevice, T>; \
+ \
+ template <> \
+ void SeluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor activations, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct SeluGrad<GPUDevice, T>;
+
+
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
@@ -127,7 +149,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
EluOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>)
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluGradOp<GPUDevice, type>)
+
+
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
@@ -154,7 +184,15 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
EluOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>)
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluGradOp<SYCLDevice, type>)
+
+
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS