about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/batch_norm_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.cc')
-rw-r--r-- tensorflow/core/kernels/batch_norm_op.cc 28
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
index 56f4e25fad..d3ed617f71 100644
--- a/tensorflow/core/kernels/batch_norm_op.cc
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -28,6 +28,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;  // Device template argument used by the SYCL kernel registrations in this file.
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class BatchNormOp : public OpKernel {
@@ -201,6 +204,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL  // #ifdef, not #if: SYCLDevice itself is declared under #ifdef.
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T"), \
+ BatchNormOp<SYCLDevice, T>);
+// Register the SYCL BatchNormOp kernel for float and double.
+TF_CALL_float(REGISTER_KERNEL);
+TF_CALL_double(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
.Device(DEVICE_CPU) \
@@ -248,4 +263,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL  // #ifdef, not #if: SYCLDevice itself is declared under #ifdef.
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T"), \
+ BatchNormGradOp<SYCLDevice, T>);
+// Register the SYCL BatchNormGradOp kernel for float and double.
+TF_CALL_float(REGISTER_KERNEL);
+TF_CALL_double(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow