diff options
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 600c54bdd2..4c2caa435f 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -25,6 +25,7 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; +using SYCLDevice = Eigen::SyclDevice; namespace { template <class T> @@ -220,9 +221,9 @@ struct ApplyMomentum<CPUDevice, T> { } }; -template <typename T> -struct ApplyAdam<CPUDevice, T> { - void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, +template <typename Device, typename T> +struct ApplyAdamNonCuda { + void operator()(const Device& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, typename TTypes<T>::ConstScalar beta1_power, typename TTypes<T>::ConstScalar beta2_power, @@ -240,6 +241,12 @@ struct ApplyAdam<CPUDevice, T> { }; template <typename T> +struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {}; +template <typename T> +struct ApplyAdam<SYCLDevice, T> : ApplyAdamNonCuda<SYCLDevice, T> {}; + + +template <typename T> struct ApplyRMSProp<CPUDevice, T> { void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, @@ -2139,6 +2146,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); + +TF_CALL_float(REGISTER_SYCL_KERNELS); +#endif + #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { |