diff options
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/batch_norm_op.cc | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index 56f4e25fad..d3ed617f71 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -28,6 +28,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template <typename Device, typename T> class BatchNormOp : public OpKernel { @@ -201,6 +204,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<T>("T"), \ + BatchNormOp<SYCLDevice, T>); + +TF_CALL_float(REGISTER_KERNEL); +TF_CALL_double(REGISTER_KERNEL); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ .Device(DEVICE_CPU) \ @@ -248,4 +263,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<T>("T"), \ + BatchNormGradOp<SYCLDevice, T>); + +TF_CALL_float(REGISTER_KERNEL); +TF_CALL_double(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow |