aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fill_functor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fill_functor.cc')
-rw-r--r--tensorflow/core/kernels/fill_functor.cc44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc
index ea0cc139f3..35d9693f54 100644
--- a/tensorflow/core/kernels/fill_functor.cc
+++ b/tensorflow/core/kernels/fill_functor.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
@@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
DEFINE_SETZERO_SYCL(int64);
#undef DEFINE_SETZERO_SYCL
#endif // TENSORFLOW_USE_SYCL
+
template <typename T>
void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
@@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
#undef DEFINE_SETONE_SYCL
#endif // TENSORFLOW_USE_SYCL
+template <typename T>
+struct FillFunctor<Eigen::ThreadPoolDevice, T> {
+ void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in) {
+ out.device(d) = out.constant(in());
+ }
+};
+
+// Explicit instantiations.
+#define DEFINE_FILL_CPU(T) \
+ template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
+
+TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
+DEFINE_FILL_CPU(quint8);
+DEFINE_FILL_CPU(quint16);
+#undef DEFINE_FILL_CPU
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct FillFunctor<Eigen::SyclDevice, T> {
+ void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in) {
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 1> rank1{1};
+#else
+ Eigen::IndexList<Eigen::type2index<1> > rank1;
+#endif
+ const int size = out.dimension(0);
+ Eigen::array<int, 1> broadcast_dims{size};
+
+ To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
+ }
+};
+
+#define DEFINE_FILL_SYCL(T) \
+ template struct FillFunctor<Eigen::SyclDevice, T>;
+DEFINE_FILL_SYCL(float);
+DEFINE_FILL_SYCL(double);
+TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
+#undef DEFINE_FILL_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow