diff options
Diffstat (limited to 'tensorflow/core/kernels/adjust_contrast_op.cc')
-rw-r--r-- | tensorflow/core/kernels/adjust_contrast_op.cc | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc index c8f12f91a6..37976f7183 100644 --- a/tensorflow/core/kernels/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -31,6 +31,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // AdjustContrastOp is deprecated as of GraphDef version >= 2 @@ -410,4 +413,25 @@ REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU), AdjustContrastOpv2<GPUDevice>); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +template <> +class AdjustContrastOpv2<SYCLDevice> : public AdjustContrastOpV2Base { + public: + explicit AdjustContrastOpv2(OpKernelConstruction* context) + : AdjustContrastOpV2Base(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const int64 shape[4] = {options.batch, options.height, options.width, + options.channels}; + functor::AdjustContrastv2<SYCLDevice>()( + context->eigen_device<SYCLDevice>(), + options.input->shaped<float, 4>(shape), options.factor->scalar<float>(), + options.output->shaped<float, 4>(shape)); + } +}; +REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL), + AdjustContrastOpv2<SYCLDevice>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow |