diff options
Diffstat (limited to 'tensorflow/core/kernels/fill_functor.cc')
-rw-r--r-- | tensorflow/core/kernels/fill_functor.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc
index ea0cc139f3..35d9693f54 100644
--- a/tensorflow/core/kernels/fill_functor.cc
+++ b/tensorflow/core/kernels/fill_functor.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/register_types.h"  // NOTE(review): presumably supplies the TF_CALL_* macros used by the instantiations in the last hunk; confirm.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 
@@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
 DEFINE_SETZERO_SYCL(int64);
 #undef DEFINE_SETZERO_SYCL
 #endif // TENSORFLOW_USE_SYCL
+
 template <typename T>
 void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
     const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
@@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
 #undef DEFINE_SETONE_SYCL
 #endif // TENSORFLOW_USE_SYCL
 
+template <typename T>
+struct FillFunctor<Eigen::ThreadPoolDevice, T> {  // Partial specialization of FillFunctor for Eigen::ThreadPoolDevice.
+  void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+    out.device(d) = out.constant(in());  // Assign a constant expression built from the scalar value in() to `out`, evaluated on device `d`.
+  }
+};
+
+// Explicit instantiations of the Eigen::ThreadPoolDevice specialization follow.
+#define DEFINE_FILL_CPU(T) \
+  template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
+
+TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);  // NOTE(review): presumably applies DEFINE_FILL_CPU once per type in TF_CALL_ALL_TYPES' list; the macro is defined outside this diff.
+DEFINE_FILL_CPU(quint8);  // quint8 and quint16 are instantiated individually; NOTE(review): presumably TF_CALL_ALL_TYPES does not cover them; confirm.
+DEFINE_FILL_CPU(quint16);
+#undef DEFINE_FILL_CPU
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct FillFunctor<Eigen::SyclDevice, T> {  // Partial specialization of FillFunctor for Eigen::SyclDevice; compiled only when TENSORFLOW_USE_SYCL is defined.
+  void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+#if !defined(EIGEN_HAS_INDEX_LIST)  // Pick how the reshape dimensions {1} are spelled: a runtime Eigen::array here, or a compile-time IndexList in the #else branch.
+    Eigen::array<int, 1> rank1{1};
+#else
+    Eigen::IndexList<Eigen::type2index<1> > rank1;
+#endif
+    const int size = out.dimension(0);  // NOTE(review): dimension(0) is stored in an int; looks intentional given To32Bit below, but confirm outputs cannot exceed INT_MAX elements.
+    Eigen::array<int, 1> broadcast_dims{size};
+
+    To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);  // Reshape the scalar with dimensions rank1, broadcast it `size` times, and assign through To32Bit(out) on device `d`.
+  }
+};
+
+#define DEFINE_FILL_SYCL(T) \
+  template struct FillFunctor<Eigen::SyclDevice, T>;
+DEFINE_FILL_SYCL(float);
+DEFINE_FILL_SYCL(double);
+TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)  // No trailing ';' here; each DEFINE_FILL_SYCL expansion already ends with one.
+#undef DEFINE_FILL_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
 } // namespace functor
 } // namespace tensorflow |