diff options
Diffstat (limited to 'tensorflow/core/kernels/softmax_op.cc')
-rw-r--r-- | tensorflow/core/kernels/softmax_op.cc | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc index c7ae93852f..de11de32f1 100644 --- a/tensorflow/core/kernels/softmax_op.cc +++ b/tensorflow/core/kernels/softmax_op.cc @@ -28,17 +28,27 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Partial specialization for a CPUDevice, that uses the Eigen implementation // from SoftmaxEigenImpl. namespace functor { -template <typename T> -struct SoftmaxFunctor<CPUDevice, T> { - void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits, +template <typename Device, typename T> +struct SoftmaxFunctorBase { + void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::Matrix softmax, const bool log) { - SoftmaxEigenImpl<CPUDevice, T>::Compute(d, logits, softmax, log); + SoftmaxEigenImpl<Device, T>::Compute(d, logits, softmax, log); } }; +template <typename T> +struct SoftmaxFunctor<CPUDevice, T> : SoftmaxFunctorBase<CPUDevice, T> {}; + +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct SoftmaxFunctor<SYCLDevice, T> : SoftmaxFunctorBase<SYCLDevice, T> {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor #define REGISTER_CPU(T) \ @@ -76,4 +86,10 @@ REGISTER_KERNEL_BUILDER( SoftmaxOp<GPUDevice, float>); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("Softmax").Device(DEVICE_SYCL).TypeConstraint<float>("T"), + SoftmaxOp<SYCLDevice, float>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow |