diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r-- | tensorflow/core/kernels/xent_op.cc | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 26f4fb2a2e..dc21cee3a8 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -28,6 +28,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> class SoftmaxXentWithLogitsOp : public OpKernel { @@ -74,17 +77,25 @@ class SoftmaxXentWithLogitsOp : public OpKernel { // Partial specialization for a CPUDevice, that uses the Eigen implementation // from XentEigenImpl. namespace functor { -template <typename T> -struct XentFunctor<CPUDevice, T> { - void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits, +template <typename Device, typename T> +struct XentFunctorBase { + void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, typename TTypes<T>::Matrix backprop) { - XentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch, loss, + XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss, backprop); } }; + +template <typename T> +struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {}; + +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor #define REGISTER_CPU(T) \ @@ -111,4 +122,11 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") SoftmaxXentWithLogitsOp<GPUDevice, double>); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_SYCL) + .TypeConstraint<float>("T"), + SoftmaxXentWithLogitsOp<SYCLDevice, float>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow |