Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
 tensorflow/core/kernels/training_ops.cc | 17 +++++++++++++++++
 1 file changed, 17 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 172449a998..641c991a7e 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -46,6 +46,17 @@ struct ApplyGradientDescent<CPUDevice, T> {
   }
 };
 
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct ApplyGradientDescent<SYCLDevice, T> {
+  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad) {
+    var.device(d) -= grad * lr();
+  }
+};
+#endif
+
 template <typename T>
 struct ApplyAdadelta<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
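The hunk above specializes the ApplyGradientDescent functor for the SYCL device: it applies the standard update var <- var - lr * grad through an Eigen device expression, mirroring the CPU specialization. The following self-contained C++ sketch illustrates the device-templated functor pattern in isolation; the CPUDevice tag type and the std::vector storage are stand-ins for this sketch only, not TensorFlow's actual classes.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative device tag types; NOT the real TensorFlow device classes.
struct CPUDevice {};
struct SYCLDevice {};

// Primary template, specialized per device, mirroring the functor
// pattern used in training_ops.cc.
template <typename Device, typename T>
struct ApplyGradientDescent;

template <typename T>
struct ApplyGradientDescent<CPUDevice, T> {
  void operator()(const CPUDevice& /*d*/, std::vector<T>& var, T lr,
                  const std::vector<T>& grad) {
    // var <- var - lr * grad, the same update the diff expresses with
    // Eigen device expressions (var.device(d) -= grad * lr()).
    for (std::size_t i = 0; i < var.size(); ++i) var[i] -= lr * grad[i];
  }
};

int main() {
  std::vector<float> var{1.0f, 2.0f};
  std::vector<float> grad{0.5f, 0.5f};
  ApplyGradientDescent<CPUDevice, float>()(CPUDevice{}, var, 0.1f, grad);
  std::cout << var[0] << " " << var[1] << "\n";  // prints: 0.95 1.95
  return 0;
}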
@@ -357,6 +368,12 @@ TF_CALL_half(REGISTER_CPU_KERNELS);
 TF_CALL_float(REGISTER_CPU_KERNELS);
 TF_CALL_double(REGISTER_CPU_KERNELS);
 
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
+TF_CALL_float(REGISTER_SYCL_KERNELS);
+#undef REGISTER_SYCL_KERNELS
+#endif
+
 #if GOOGLE_CUDA
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
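The second hunk registers the SYCL kernel for float only: REGISTER_KERNELS(D, T) is a local macro in training_ops.cc that expands to the kernel registration for device D and element type T, and TF_CALL_float invokes a given macro with the float type. Below is a simplified, hypothetical sketch of that macro stacking; RegisterKernel and the print output are stand-ins for the real REGISTER_KERNEL_BUILDER machinery.

#include <iostream>
#include <string>

// Stand-in for what kernel registration does; illustrative only.
void RegisterKernel(const std::string& device, const std::string& type) {
  std::cout << "ApplyGradientDescent registered for " << device << "/"
            << type << "\n";
}

// Simplified analogues of the macros used in the diff.
#define REGISTER_KERNELS(D, T) RegisterKernel(#D, #T);
#define TF_CALL_float(m) m(float)

int main() {
  // Mirrors the diff: only the float instantiation is registered for
  // SYCL, and the helper macro is #undef'd afterwards so it cannot
  // leak into later registration blocks.
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T)
  TF_CALL_float(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
  return 0;
}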