diff options
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 172449a998..641c991a7e 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -46,6 +46,17 @@ struct ApplyGradientDescent<CPUDevice, T> { } }; +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct ApplyGradientDescent<SYCLDevice, T> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstFlat grad) { + var.device(d) -= grad * lr(); + } +}; +#endif + template <typename T> struct ApplyAdadelta<CPUDevice, T> { void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, @@ -357,6 +368,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); +TF_CALL_float(REGISTER_SYCL_KERNELS); +#undef REGISTER_SYCL_KERNELS +#endif + #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { |