From 925fb6b93710b95082ba44d30405289dff3707eb Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Tue, 10 Jun 2014 09:14:44 -0700
Subject: TensorEval are now typed on the device: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.

Started work on partial evaluations.
---
 .../Eigen/CXX11/src/Tensor/TensorConvolution.h | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

(limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h')

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
index 501e9a522..a554b8260 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
@@ -94,27 +94,27 @@ class TensorConvolutionOp : public TensorBase<TensorConvolutionOp<Indices, Input
-template <typename Indices, typename InputArgType, typename KernelArgType>
-struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType> >
+template <typename Indices, typename InputArgType, typename KernelArgType, typename Device>
+struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Device>
 {
   typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;
 
-  static const int NumDims = TensorEvaluator<InputArgType>::Dimensions::count;
+  static const int NumDims = TensorEvaluator<InputArgType, Device>::Dimensions::count;
   static const int KernelDims = Indices::size;
   typedef typename XprType::Index Index;
   typedef DSizes<Index, NumDims> Dimensions;
 
   enum {
-    IsAligned = TensorEvaluator<InputArgType>::IsAligned & TensorEvaluator<KernelArgType>::IsAligned,
+    IsAligned = TensorEvaluator<InputArgType, Device>::IsAligned & TensorEvaluator<KernelArgType, Device>::IsAligned,
     PacketAccess = /*TensorEvaluator<InputArgType>::PacketAccess & TensorEvaluator<KernelArgType>::PacketAccess */ false,
   };
 
-  TensorEvaluator(const XprType& op)
-      : m_inputImpl(op.inputExpression()), m_kernelImpl(op.kernelExpression()), m_dimensions(op.inputExpression().dimensions())
+  TensorEvaluator(const XprType& op, const Device& device)
+      : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_dimensions(op.inputExpression().dimensions())
   {
-    const typename TensorEvaluator<InputArgType>::Dimensions& input_dims = m_inputImpl.dimensions();
-    const typename TensorEvaluator<KernelArgType>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
+    const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
+    const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
 
     for (int i = 0; i < NumDims; ++i) {
       if (i > 0) {
@@ -200,8 +200,8 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr
   array<Index, KernelDims> m_indexStride;
   array<Index, KernelDims> m_kernelStride;
   Dimensions m_dimensions;
-  TensorEvaluator<InputArgType> m_inputImpl;
-  TensorEvaluator<KernelArgType> m_kernelImpl;
+  TensorEvaluator<InputArgType, Device> m_inputImpl;
+  TensorEvaluator<KernelArgType, Device> m_kernelImpl;
 };
 
 
-- 
cgit v1.2.3