aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xent_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r--tensorflow/core/kernels/xent_op.cc26
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 26f4fb2a2e..dc21cee3a8 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -28,6 +28,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class SoftmaxXentWithLogitsOp : public OpKernel {
@@ -74,17 +77,25 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
// Partial specialization for a CPUDevice, that uses the Eigen implementation
// from XentEigenImpl.
namespace functor {
-template <typename T>
-struct XentFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+template <typename Device, typename T>
+struct XentFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop) {
- XentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss,
backprop);
}
};
+
+template <typename T>
+struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
#define REGISTER_CPU(T) \
@@ -111,4 +122,11 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<float>("T"),
+ SoftmaxXentWithLogitsOp<SYCLDevice, float>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow