aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_sycl_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_sycl_common.h')
-rw-r--r--tensorflow/core/kernels/cwise_ops_sycl_common.h24
1 files changed, 6 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h
index a0decbce87..3f6ff7303d 100644
--- a/tensorflow/core/kernels/cwise_ops_sycl_common.h
+++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h
@@ -31,14 +31,6 @@ namespace functor {
typedef Eigen::SyclDevice SYCLDevice;
-template <typename Index, int N> Eigen::array<Index, N> GenerateArrayOfOnes() {
- Eigen::array<Index, N> result;
- for (int i = 0; i < N; ++i) {
- result[i] = 1;
- }
- return result;
-}
-
template <typename OUT, typename RHS>
void Assign(const SYCLDevice& d, OUT out, RHS rhs) {
out.device(d) = rhs;
@@ -67,11 +59,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
typename Functor::tin_type in, bool* error) {
typedef typename Functor::func Binary;
constexpr int NumDims = Functor::tin_type::NumDimensions;
- typedef typename Functor::tin_type::Scalar T;
- typedef typename Functor::tin_type::Index Index;
- Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>();
- Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim);
- out.device(d) = tmp.broadcast(in.dimensions()).binaryExpr(in, Binary());
+ static_assert(NumDims == 1, "Unexpected size");
+ Eigen::Sizes<1> scalar_dim;
+ out.device(d) = scalar.reshape(scalar_dim).broadcast(in.dimensions()).binaryExpr(in, Binary());
}
void Right(const SYCLDevice& d, typename Functor::tout_type out,
@@ -79,11 +69,9 @@ struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
typename Functor::tscalar_type scalar, bool* error) {
typedef typename Functor::func Binary;
constexpr int NumDims = Functor::tin_type::NumDimensions;
- typedef typename Functor::tin_type::Scalar T;
- typedef typename Functor::tin_type::Index Index;
- Eigen::array<Index, NumDims> scalar_dim = GenerateArrayOfOnes<Index, NumDims>();
- Eigen::TensorMap<Eigen::Tensor<T, NumDims, Eigen::RowMajor>> tmp(scalar.data(), scalar_dim);
- out.device(d) = in.binaryExpr(tmp.broadcast(in.dimensions()), Binary());
+ static_assert(NumDims == 1, "Unexpected size");
+ Eigen::Sizes<1> scalar_dim;
+ out.device(d) = in.binaryExpr(scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary());
}
void BCast(const SYCLDevice& d,