aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r--tensorflow/core/kernels/training_ops.cc19
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 600c54bdd2..4c2caa435f 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -25,6 +25,7 @@ namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
+using SYCLDevice = Eigen::SyclDevice;
namespace {
template <class T>
@@ -220,9 +221,9 @@ struct ApplyMomentum<CPUDevice, T> {
}
};
-template <typename T>
-struct ApplyAdam<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+template <typename Device, typename T>
+struct ApplyAdamNonCuda {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
typename TTypes<T>::ConstScalar beta1_power,
typename TTypes<T>::ConstScalar beta2_power,
@@ -240,6 +241,12 @@ struct ApplyAdam<CPUDevice, T> {
};
template <typename T>
+struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
+template <typename T>
+struct ApplyAdam<SYCLDevice, T> : ApplyAdamNonCuda<SYCLDevice, T> {};
+
+
+template <typename T>
struct ApplyRMSProp<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
@@ -2139,6 +2146,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
+
+TF_CALL_float(REGISTER_SYCL_KERNELS);
+#endif
+
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {