aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/softmax_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/softmax_op.cc')
-rw-r--r--tensorflow/core/kernels/softmax_op.cc24
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index c7ae93852f..de11de32f1 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -28,17 +28,27 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Partial specialization for a CPUDevice, that uses the Eigen implementation
// from SoftmaxEigenImpl.
namespace functor {
-template <typename T>
-struct SoftmaxFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+template <typename Device, typename T>
+struct SoftmaxFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::Matrix softmax, const bool log) {
- SoftmaxEigenImpl<CPUDevice, T>::Compute(d, logits, softmax, log);
+ SoftmaxEigenImpl<Device, T>::Compute(d, logits, softmax, log);
}
};
+template <typename T>
+struct SoftmaxFunctor<CPUDevice, T> : SoftmaxFunctorBase<CPUDevice, T> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SoftmaxFunctor<SYCLDevice, T> : SoftmaxFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
#define REGISTER_CPU(T) \
@@ -76,4 +86,10 @@ REGISTER_KERNEL_BUILDER(
SoftmaxOp<GPUDevice, float>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("Softmax").Device(DEVICE_SYCL).TypeConstraint<float>("T"),
+ SoftmaxOp<SYCLDevice, float>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow