diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_sycl_common.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_sycl_common.h | 24 |
1 files changed, 6 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h index a0decbce87..3f6ff7303d 100644 --- a/tensorflow/core/kernels/cwise_ops_sycl_common.h +++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h @@ -31,14 +31,6 @@ namespace functor { typedef Eigen::SyclDevice SYCLDevice; -template <typename Index, int N> Eigen::array<Index, N> GenerateArrayOfOnes() { - Eigen::array<Index, N> result; - for (int i = 0; i < N; ++i) { - result[i] = 1; - } - return result; -} - template <typename OUT, typename RHS> void Assign(const SYCLDevice& d, OUT out, RHS rhs) { out.device(d) = rhs; @@ -67,11 +59,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> { typename Functor::tin_type in, bool* error) { typedef typename Functor::func Binary; constexpr int NumDims = Functor::tin_type::NumDimensions; - typedef typename Functor::tin_type::Scalar T; - typedef typename Functor::tin_type::Index Index; - Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>(); - Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim); - out.device(d) = tmp.broadcast(in.dimensions()).binaryExpr(in, Binary()); + static_assert(NumDims == 1, "Unexpected size"); + Eigen::Sizes<1> scalar_dim; + out.device(d) = scalar.reshape(scalar_dim).broadcast(in.dimensions()).binaryExpr(in, Binary()); } void Right(const SYCLDevice& d, typename Functor::tout_type out, @@ -79,11 +69,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> { typename Functor::tscalar_type scalar, bool* error) { typedef typename Functor::func Binary; constexpr int NumDims = Functor::tin_type::NumDimensions; - typedef typename Functor::tin_type::Scalar T; - typedef typename Functor::tin_type::Index Index; - Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>(); - Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim); - out.device(d) = in.binaryExpr(tmp.broadcast(in.dimensions()), Binary()); + static_assert(NumDims == 1, "Unexpected size"); + Eigen::Sizes<1> scalar_dim; + out.device(d) = in.binaryExpr(scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary()); } void BCast(const SYCLDevice& d, |