diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h | 32 |
1 files changed, 27 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h index d06f40cd8..c0f33ba2d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h @@ -119,6 +119,12 @@ struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } +#ifdef EIGEN_USE_SYCL + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorEvaluator<ArgType, Device>& impl() const { + return m_impl; + } +#endif + protected: TensorEvaluator<ArgType, Device> m_impl; }; @@ -172,7 +178,7 @@ class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Di EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr, const ReduceOp& reduce_op, - const int return_dim, + const Index return_dim, const Dims& reduce_dims) : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {} @@ -187,12 +193,12 @@ class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Di const Dims& reduce_dims() const { return m_reduce_dims; } EIGEN_DEVICE_FUNC - int return_dim() const { return m_return_dim; } + Index return_dim() const { return m_return_dim; } protected: typename XprType::Nested m_xpr; const ReduceOp m_reduce_op; - const int m_return_dim; + const Index m_return_dim; const Dims m_reduce_dims; }; @@ -222,7 +228,11 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_orig_impl(op.expression(), device), m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device), - m_return_dim(op.return_dim()) { + m_return_dim(op.return_dim()) +#ifdef EIGEN_USE_SYCL + ,m_device(device) +#endif + { gen_strides(m_orig_impl.dimensions(), m_strides); if (Layout == static_cast<int>(ColMajor)) { @@ -252,7 +262,16 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div; } + #ifndef EIGEN_USE_SYCL EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } + #else // following functions are required by sycl + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TupleType* data() const { return m_impl.data(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index return_dim() const {return m_return_dim;} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const StrideDims& strides() const {return m_strides;} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Index& stride_mod() const {return m_stride_mod;} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Index& stride_div() const {return m_stride_div;} + const Device& device() const{return m_device;} + #endif EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { @@ -288,10 +307,13 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi protected: TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl; TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl; - const int m_return_dim; + const Index m_return_dim; StrideDims m_strides; Index m_stride_mod; Index m_stride_div; +#ifdef EIGEN_USE_SYCL + const Device& m_device; +#endif }; } // end namespace Eigen |