// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Eric Martin // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H namespace Eigen { /** \class TensorContraction * \ingroup CXX11_Tensor_Module * * \brief Tensor contraction class. * * */ namespace internal { template struct traits > { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename scalar_product_traits::ReturnType Scalar; typedef typename scalar_product_traits::StorageKind, typename traits::StorageKind>::ReturnType StorageKind; typedef typename promote_index_type::Index, typename traits::Index>::type Index; typedef typename LhsXprType::Nested LhsNested; typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; // From NumDims below. static const int NumDimensions = traits::NumDimensions + traits::NumDimensions - 2 * array_size::value; static const int Layout = traits::Layout; enum { Flags = 0, }; }; template struct eval, Eigen::Dense> { typedef const TensorContractionOp& type; }; template struct nested, 1, typename eval >::type> { typedef TensorContractionOp type; }; template struct traits, Device_> > { typedef Indices_ Indices; typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; typedef Device_ Device; // From NumDims below. static const int NumDimensions = traits::NumDimensions + traits::NumDimensions - 2 * array_size::value; }; } // end namespace internal template class TensorContractionOp : public TensorBase > { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename internal::scalar_product_traits::ReturnType CoeffReturnType; typedef typename Eigen::internal::nested::type Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} EIGEN_DEVICE_FUNC const Indices& indices() const { return m_indices; } /** \returns the nested expressions */ EIGEN_DEVICE_FUNC const typename internal::remove_all::type& lhsExpression() const { return m_lhs_xpr; } EIGEN_DEVICE_FUNC const typename internal::remove_all::type& rhsExpression() const { return m_rhs_xpr; } protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const Indices m_indices; }; template struct TensorContractionEvaluatorBase { typedef typename internal::traits::Indices Indices; typedef typename internal::traits::LeftArgType LeftArgType; typedef typename internal::traits::RightArgType RightArgType; typedef typename internal::traits::Device Device; typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename PacketType::type PacketReturnType; enum { IsAligned = true, PacketAccess = (internal::packet_traits::size > 1), BlockAccess = false, Layout = TensorEvaluator::Layout, CoordAccess = false, // to be implemented }; // Most of the code is assuming that both input tensors are ColMajor. If the // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: // If we want to compute A * B = C, where A is LHS and B is RHS, the code // will pretend B is LHS and A is RHS. typedef typename internal::conditional< static_cast(Layout) == static_cast(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; typedef typename internal::conditional< static_cast(Layout) == static_cast(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; static const int LDims = internal::array_size::Dimensions>::value; static const int RDims = internal::array_size::Dimensions>::value; static const int ContractDims = internal::array_size::value; static const int NumDims = LDims + RDims - 2 * ContractDims; typedef array left_dim_mapper_t; typedef array right_dim_mapper_t; typedef array contract_t; typedef array left_nocontract_t; typedef array right_nocontract_t; typedef DSizes Dimensions; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType& op, const Device& device) : m_leftImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), op.lhsExpression(), op.rhsExpression()), device), m_rightImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), op.rhsExpression(), op.lhsExpression()), device), m_device(device), m_result(NULL) { EIGEN_STATIC_ASSERT((static_cast(TensorEvaluator::Layout) == static_cast(TensorEvaluator::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE); eigen_assert((contract_t::size > 0) && "Must contract on some indices"); DSizes eval_left_dims; DSizes eval_right_dims; array, ContractDims> eval_op_indices; if (static_cast(Layout) == static_cast(ColMajor)) { // For ColMajor, we keep using the existing dimensions for (int i = 0; i < LDims; i++) { eval_left_dims[i] = m_leftImpl.dimensions()[i]; } for (int i = 0; i < RDims; i++) { eval_right_dims[i] = m_rightImpl.dimensions()[i]; } // We keep the pairs of contracting indices. for (int i = 0; i < ContractDims; i++) { eval_op_indices[i].first = op.indices()[i].first; eval_op_indices[i].second = op.indices()[i].second; } } else { // For RowMajor, we need to reverse the existing dimensions for (int i = 0; i < LDims; i++) { eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1]; } for (int i = 0; i < RDims; i++) { eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1]; } // We need to flip all the pairs of contracting indices as well as // reversing the dimensions. for (int i = 0; i < ContractDims; i++) { eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second; eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first; } } array lhs_strides; if (LDims > 0) { lhs_strides[0] = 1; for (int i = 0; i < LDims-1; ++i) { lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i]; } } array rhs_strides; if (RDims > 0) { rhs_strides[0] = 1; for (int i = 0; i < RDims-1; ++i) { rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; } } if (m_i_strides.size() > 0) m_i_strides[0] = 1; if (m_j_strides.size() > 0) m_j_strides[0] = 1; if (m_k_strides.size() > 0) m_k_strides[0] = 1; m_i_size = 1; m_j_size = 1; m_k_size = 1; // To compute the dimension, we simply concatenate the non-contracting // dimensions of the left and then the right tensor. Additionally, I also // want to compute the cumulative products of the left non-contracting // dimensions, right non-contracting dimensions, and the contracting // dimensions (in the order of the contraction) to aid in the later // computation of tensor indices for matrix indices. m_lhs_inner_dim_contiguous = true; int dim_idx = 0; int nocontract_idx = 0; for (int i = 0; i < LDims; i++) { // find if we are contracting on index i of left tensor bool contracting = false; for (int j = 0; j < ContractDims; j++) { if (eval_op_indices[j].first == i) { contracting = true; break; } } if (!contracting) { // add dimension size to output dimensions m_dimensions[dim_idx] = eval_left_dims[i]; m_left_nocontract_strides[nocontract_idx] = lhs_strides[i]; if (dim_idx != i) { m_lhs_inner_dim_contiguous = false; } if (nocontract_idx+1 < internal::array_size::value) { m_i_strides[nocontract_idx+1] = m_i_strides[nocontract_idx] * eval_left_dims[i]; } else { m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i]; } dim_idx++; nocontract_idx++; } } nocontract_idx = 0; for (int i = 0; i < RDims; i++) { bool contracting = false; // find if we are contracting on index i of right tensor for (int j = 0; j < ContractDims; j++) { if (eval_op_indices[j].second == i) { contracting = true; break; } } if (!contracting) { m_dimensions[dim_idx] = eval_right_dims[i]; if (nocontract_idx+1 < internal::array_size::value) { m_j_strides[nocontract_idx+1] = m_j_strides[nocontract_idx] * eval_right_dims[i]; } else { m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i]; } m_right_nocontract_strides[nocontract_idx] = rhs_strides[i]; dim_idx++; nocontract_idx++; } } // now build contraction cumprod. We assumed above that non-contracting axes // are represented in the same order in the matrix as they are in the tensor. // This is not the case for contracting axes. As the contracting axes must be // of the same size in each tensor, I'll only look at the first tensor here. m_rhs_inner_dim_contiguous = true; m_rhs_inner_dim_reordered = false; for (int i = 0; i < ContractDims; i++) { Index left = eval_op_indices[i].first; Index right = eval_op_indices[i].second; Index size = eval_left_dims[left]; eigen_assert(size == eval_right_dims[right] && "Contraction axes must be same size"); if (i+1 < internal::array_size::value) { m_k_strides[i+1] = m_k_strides[i] * size; } else { m_k_size = m_k_strides[i] * size; } m_left_contracting_strides[i] = lhs_strides[left]; m_right_contracting_strides[i] = rhs_strides[right]; if (i > 0 && right < eval_op_indices[i-1].second) { m_rhs_inner_dim_reordered = true; } if (right != i) { m_rhs_inner_dim_contiguous = false; } } // If the layout is RowMajor, we need to reverse the m_dimensions if (static_cast(Layout) == static_cast(RowMajor)) { for (int i = 0, j = NumDims - 1; i < j; i++, j--) { numext::swap(m_dimensions[i], m_dimensions[j]); } } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { m_leftImpl.evalSubExprsIfNeeded(NULL); m_rightImpl.evalSubExprsIfNeeded(NULL); if (data) { evalTo(data); return false; } else { m_result = static_cast(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); evalTo(m_result); return true; } } EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { if (this->m_lhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { static_cast(this)->template evalProduct(buffer); } else { static_cast(this)->template evalProduct(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { static_cast(this)->template evalProduct(buffer); } else { static_cast(this)->template evalProduct(buffer); } } } else { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { static_cast(this)->template evalProduct(buffer); } else { static_cast(this)->template evalProduct(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { static_cast(this)->template evalProduct(buffer); } else { static_cast(this)->template evalProduct(buffer); } } } } template void evalGemv(Scalar* buffer) const { const Index rows = m_i_size; const Index cols = m_k_size; typedef typename internal::remove_const::type LhsScalar; typedef typename internal::remove_const::type RhsScalar; typedef TensorEvaluator LeftEvaluator; typedef TensorEvaluator RightEvaluator; const int lhs_packet_size = PacketType::size; const int rhs_packet_size = PacketType::size; typedef internal::TensorContractionInputMapper LhsMapper; typedef internal::TensorContractionInputMapper RhsMapper; LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides); const RhsScalar alpha(1); const Index resIncr(1); // zero out the result buffer (which must be of size at least rows * sizeof(Scalar) m_device.memset(buffer, 0, rows * sizeof(Scalar)); internal::general_matrix_vector_product::run( rows, cols, lhs, rhs, buffer, resIncr, alpha); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); m_rightImpl.cleanup(); if (m_result != NULL) { m_device.deallocate(m_result); m_result = NULL; } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_result[index]; } template EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { return internal::ploadt(m_result + index); } EIGEN_DEVICE_FUNC Scalar* data() const { return m_result; } protected: // Note: nvcc doesn't like implicit copy constructor. If this is needed anywhere, // then we'll have to write an explicit copy constructor... //TensorContractionEvaluatorBase(const TensorContractionEvaluatorBase&); TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&); Dimensions m_dimensions; contract_t m_k_strides; contract_t m_left_contracting_strides; contract_t m_right_contracting_strides; bool m_lhs_inner_dim_contiguous; bool m_rhs_inner_dim_contiguous; bool m_rhs_inner_dim_reordered; left_nocontract_t m_i_strides; right_nocontract_t m_j_strides; left_nocontract_t m_left_nocontract_strides; right_nocontract_t m_right_nocontract_strides; Index m_i_size; Index m_j_size; Index m_k_size; TensorEvaluator m_leftImpl; TensorEvaluator m_rightImpl; const Device& m_device; Scalar* m_result; }; // evaluator for default device template struct TensorEvaluator, Device> : public TensorContractionEvaluatorBase< TensorEvaluator, Device> > { typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename PacketType::type PacketReturnType; enum { Layout = TensorEvaluator::Layout, }; // Most of the code is assuming that both input tensors are ColMajor. If the // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: // If we want to compute A * B = C, where A is LHS and B is RHS, the code // will pretend B is LHS and A is RHS. typedef typename internal::conditional< static_cast(Layout) == static_cast(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; typedef typename internal::conditional< static_cast(Layout) == static_cast(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; static const int LDims = internal::array_size::Dimensions>::value; static const int RDims = internal::array_size::Dimensions>::value; static const int ContractDims = internal::array_size::value; typedef array left_dim_mapper_t; typedef array right_dim_mapper_t; typedef array contract_t; typedef array left_nocontract_t; typedef array right_nocontract_t; static const int NumDims = LDims + RDims - 2 * ContractDims; // Could we use NumDimensions here? typedef DSizes Dimensions; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } template void evalProduct(Scalar* buffer) const { if (this->m_j_size == 1) { this->template evalGemv(buffer); return; } evalGemm(buffer); } template EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; // rows in left side const Index m = this->m_i_size; // columns in right side const Index n = this->m_j_size; // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); // define mr, nr, and all of my data mapper types typedef typename internal::remove_const::type LhsScalar; typedef typename internal::remove_const::type RhsScalar; typedef typename internal::gebp_traits Traits; const Index nr = Traits::nr; const Index mr = Traits::mr; typedef TensorEvaluator LeftEvaluator; typedef TensorEvaluator RightEvaluator; const int lhs_packet_size = internal::packet_traits::size; const int rhs_packet_size = internal::packet_traits::size; typedef internal::TensorContractionInputMapper LhsMapper; typedef internal::TensorContractionInputMapper RhsMapper; typedef internal::blas_data_mapper OutputMapper; // declare GEBP packing and kernel structs // TODO: packing could be faster sometimes if we supported row major tensor mappers internal::gemm_pack_lhs pack_lhs; internal::gemm_pack_rhs pack_rhs; // TODO: replace false, false with conjugate values? internal::gebp_kernel gebp; // initialize data mappers LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, this->m_left_contracting_strides, this->m_k_strides); RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, this->m_right_contracting_strides, this->m_k_strides); OutputMapper output(buffer, m); // TODO: refine arguments here (am I row or col major, etc) typedef typename internal::gemm_blocking_space BlockingType; // compute block sizes (which depend on number of threads) // last parameter is true to use L3 blocking, 2nd to last parameter is 1 to // indicate 1 thread BlockingType blocking(m, n, k, 1, true); const Index kc = blocking.kc(); const Index mc = (std::min)(m, blocking.mc()); const Index nc = (std::min)(n, blocking.nc()); // sizes of submatrices to live in cache. see Goto paper. int sizeA = blocking.mc() * kc; int sizeB = kc * blocking.nc(); // note: m_device.allocate should return 16 byte aligned pointers, but if blockA and blockB // aren't 16 byte aligned segfaults will happen due to SIMD instructions LhsScalar* blockA = static_cast(this->m_device.allocate(sizeA * sizeof(LhsScalar))); RhsScalar* blockB = static_cast(this->m_device.allocate(sizeB * sizeof(RhsScalar))); for(Index i2=0; i2m_device.deallocate(blockA); this->m_device.deallocate(blockB); } }; } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H