diff options
Diffstat (limited to 'tensorflow/core/kernels/dense_update_functor.h')
-rw-r--r-- | tensorflow/core/kernels/dense_update_functor.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h index 54b080c83b..4aefe26c54 100644 --- a/tensorflow/core/kernels/dense_update_functor.h +++ b/tensorflow/core/kernels/dense_update_functor.h @@ -24,6 +24,9 @@ limitations under the License. namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL enum DenseUpdateType { ADD, SUB, ASSIGN }; @@ -59,6 +62,32 @@ struct DenseUpdate<CPUDevice, T, ASSIGN> { } }; +#ifdef TENSORFLOW_USE_SYCL +template <typename T> +struct DenseUpdate<SYCLDevice, T, ADD> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) += update; + } +}; + +template <typename T> +struct DenseUpdate<SYCLDevice, T, SUB> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) -= update; + } +}; + +template <typename T> +struct DenseUpdate<SYCLDevice, T, ASSIGN> { + void operator()(const SYCLDevice& d, typename TTypes<T>::Flat params, + typename TTypes<T>::ConstFlat update) { + params.device(d) = update; + } +}; +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor } // end namespace tensorflow |