aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dense_update_functor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dense_update_functor.h')
-rw-r--r--tensorflow/core/kernels/dense_update_functor.h29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 54b080c83b..4aefe26c54 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -24,6 +24,9 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
enum DenseUpdateType { ADD, SUB, ASSIGN };
@@ -59,6 +62,32 @@ struct DenseUpdate<CPUDevice, T, ASSIGN> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct DenseUpdate<SYCLDevice, T, ADD> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) += update;
+ }
+};
+
+template <typename T>
+struct DenseUpdate<SYCLDevice, T, SUB> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) -= update;
+ }
+};
+
+template <typename T>
+struct DenseUpdate<SYCLDevice, T, ASSIGN> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) = update;
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
} // end namespace tensorflow