aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/adjust_contrast_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/adjust_contrast_op.cc')
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.cc24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
index c8f12f91a6..37976f7183 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -31,6 +31,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif
// AdjustContrastOp is deprecated as of GraphDef version >= 2
@@ -410,4 +413,25 @@ REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU),
AdjustContrastOpv2<GPUDevice>);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+class AdjustContrastOpv2<SYCLDevice> : public AdjustContrastOpV2Base {
+ public:
+ explicit AdjustContrastOpv2(OpKernelConstruction* context)
+ : AdjustContrastOpV2Base(context) {}
+
+ void DoCompute(OpKernelContext* context,
+ const ComputeOptions& options) override {
+ const int64 shape[4] = {options.batch, options.height, options.width,
+ options.channels};
+ functor::AdjustContrastv2<SYCLDevice>()(
+ context->eigen_device<SYCLDevice>(),
+ options.input->shaped<float, 4>(shape), options.factor->scalar<float>(),
+ options.output->shaped<float, 4>(shape));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL),
+ AdjustContrastOpv2<SYCLDevice>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow