diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-10 09:14:44 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-10 09:14:44 -0700 |
commit | 925fb6b93710b95082ba44d30405289dff3707eb (patch) | |
tree | 004ce9af64e2ffdb9148a286938a2710ee8c2607 /unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h | |
parent | a77458a8ff2a83e716add62253eb50ef64980b21 (diff) |
TensorEvaluator classes are now templated on the device type: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.
Started work on partial evaluations.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h | 20 |
1 file changed, 10 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h index 501e9a522..a554b8260 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h @@ -94,27 +94,27 @@ class TensorConvolutionOp : public TensorBase<TensorConvolutionOp<Indices, Input }; -template<typename Indices, typename InputArgType, typename KernelArgType> -struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType> > +template<typename Indices, typename InputArgType, typename KernelArgType, typename Device> +struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Device> { typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType; - static const int NumDims = TensorEvaluator<InputArgType>::Dimensions::count; + static const int NumDims = TensorEvaluator<InputArgType, Device>::Dimensions::count; static const int KernelDims = Indices::size; typedef typename XprType::Index Index; typedef DSizes<Index, NumDims> Dimensions; enum { - IsAligned = TensorEvaluator<InputArgType>::IsAligned & TensorEvaluator<KernelArgType>::IsAligned, + IsAligned = TensorEvaluator<InputArgType, Device>::IsAligned & TensorEvaluator<KernelArgType, Device>::IsAligned, PacketAccess = /*TensorEvaluator<InputArgType>::PacketAccess & TensorEvaluator<KernelArgType>::PacketAccess */ false, }; - TensorEvaluator(const XprType& op) - : m_inputImpl(op.inputExpression()), m_kernelImpl(op.kernelExpression()), m_dimensions(op.inputExpression().dimensions()) + TensorEvaluator(const XprType& op, const Device& device) + : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_dimensions(op.inputExpression().dimensions()) { - const typename TensorEvaluator<InputArgType>::Dimensions& input_dims = m_inputImpl.dimensions(); - const typename TensorEvaluator<KernelArgType>::Dimensions& kernel_dims = 
m_kernelImpl.dimensions(); + const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions(); + const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions(); for (int i = 0; i < NumDims; ++i) { if (i > 0) { @@ -200,8 +200,8 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr array<Index, KernelDims> m_indexStride; array<Index, KernelDims> m_kernelStride; Dimensions m_dimensions; - TensorEvaluator<InputArgType> m_inputImpl; - TensorEvaluator<KernelArgType> m_kernelImpl; + TensorEvaluator<InputArgType, Device> m_inputImpl; + TensorEvaluator<KernelArgType, Device> m_kernelImpl; }; |