aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-10 09:14:44 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-10 09:14:44 -0700
commit925fb6b93710b95082ba44d30405289dff3707eb (patch)
tree004ce9af64e2ffdb9148a286938a2710ee8c2607 /unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
parenta77458a8ff2a83e716add62253eb50ef64980b21 (diff)
TensorEval are now typed on the device: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.
Started work on partial evaluations.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h20
1 files changed, 10 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
index 501e9a522..a554b8260 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
@@ -94,27 +94,27 @@ class TensorConvolutionOp : public TensorBase<TensorConvolutionOp<Indices, Input
};
-template<typename Indices, typename InputArgType, typename KernelArgType>
-struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType> >
+template<typename Indices, typename InputArgType, typename KernelArgType, typename Device>
+struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Device>
{
typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;
- static const int NumDims = TensorEvaluator<InputArgType>::Dimensions::count;
+ static const int NumDims = TensorEvaluator<InputArgType, Device>::Dimensions::count;
static const int KernelDims = Indices::size;
typedef typename XprType::Index Index;
typedef DSizes<Index, NumDims> Dimensions;
enum {
- IsAligned = TensorEvaluator<InputArgType>::IsAligned & TensorEvaluator<KernelArgType>::IsAligned,
+ IsAligned = TensorEvaluator<InputArgType, Device>::IsAligned & TensorEvaluator<KernelArgType, Device>::IsAligned,
PacketAccess = /*TensorEvaluator<InputArgType>::PacketAccess & TensorEvaluator<KernelArgType>::PacketAccess */
false,
};
- TensorEvaluator(const XprType& op)
- : m_inputImpl(op.inputExpression()), m_kernelImpl(op.kernelExpression()), m_dimensions(op.inputExpression().dimensions())
+ TensorEvaluator(const XprType& op, const Device& device)
+ : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_dimensions(op.inputExpression().dimensions())
{
- const typename TensorEvaluator<InputArgType>::Dimensions& input_dims = m_inputImpl.dimensions();
- const typename TensorEvaluator<KernelArgType>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
+ const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
+ const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
for (int i = 0; i < NumDims; ++i) {
if (i > 0) {
@@ -200,8 +200,8 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr
array<Index, KernelDims> m_indexStride;
array<Index, KernelDims> m_kernelStride;
Dimensions m_dimensions;
- TensorEvaluator<InputArgType> m_inputImpl;
- TensorEvaluator<KernelArgType> m_kernelImpl;
+ TensorEvaluator<InputArgType, Device> m_inputImpl;
+ TensorEvaluator<KernelArgType, Device> m_kernelImpl;
};