diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
commit | 01fd4096d395e7b816459f571bf2328c8435cc37 (patch) | |
tree | 02b928b34f77c3e63126c3175b6ea06174818f51 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | |
parent | 5539587b1f5b5922b2419b0a4468cf2f393def51 (diff) |
Fuse computations into the Tensor contractions using output kernel
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 113 |
1 files changed, 92 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 979fcf4d9..85126a127 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar> #endif -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> > { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type, @@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > }; }; -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense> +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense> { - typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type; + typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type; }; -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type> +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type> { - typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type; + typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type; }; -template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_> -struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > { +template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_> +struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > { typedef Indices_ Indices; typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; + typedef OutputKernelType_ OutputKernelType; typedef Device_ Device; // From NumDims below. @@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, } // end namespace internal -template<typename Indices, typename LhsXprType, typename RhsXprType> -class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors> +// Tensor contraction params that should enable to get from output matrix +// 2-dimensional coordinates to the output tensor dimensions. +struct TensorContractionParams { + // TensorContraction evaluator assumes that both tensors are in ColMajor + // layout, if tensors are in RowMajor evaluator swap lhs with rhs. + bool swapped_arguments; +}; + +// Output kernel allows to fuse operations into the tensor contraction. +// +// Examples: +// 1. Elementwise Relu transformation following Conv2D. +// 2. AddBias to the Conv2D output channels dimension. +// +// See expected implementation in NoOpOutputKernel. +struct OutputKernel { + template <typename Index, typename Scalar> + using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>; +}; + +// Output kernel that does absolutely nothing. +struct NoOpOutputKernel { + /** + * Tensor contraction evaluator calls this kernel after finishing each block + * of output matrix. Output blocks belong to the 2-dimensional output tensor. + * + * TensorContractionParams contains contraction dimensions information + * required to map output 2-d space into the expected output tensor space + * (potentially higher dimensional). + * + * \param[in] output_mapper Access to output tensor memory + * \param[in] params Tensor contraction parameters + * \param[in] i Index of a first row available through output_mapper + * \param[in] j Index of a first column available through output_mapper + * \param[in] num_rows Number of available rows + * \param[in] num_cols Number of available columns + */ + template <typename Index, typename Scalar> + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper<Index, Scalar>& output_mapper, + const TensorContractionParams& params, Index i, Index j, Index num_rows, + Index num_cols) const {} +}; + +template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; @@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( - const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims, + const OutputKernelType& output_kernel = OutputKernelType()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims), + m_output_kernel(output_kernel) {} EIGEN_DEVICE_FUNC const Indices& indices() const { return m_indices; } @@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp const typename internal::remove_all<typename RhsXprType::Nested>::type& rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const OutputKernelType& outputKernel() const { return m_output_kernel; } + protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const Indices m_indices; + const OutputKernelType m_output_kernel; }; @@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase typedef typename internal::traits<Derived>::Indices Indices; typedef typename internal::traits<Derived>::LeftArgType LeftArgType; typedef typename internal::traits<Derived>::RightArgType RightArgType; + typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType; typedef typename internal::traits<Derived>::Device Device; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase op.lhsExpression(), op.rhsExpression()), device), m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), op.rhsExpression(), op.lhsExpression()), device), + m_output_kernel(op.outputKernel()), m_device(device), m_result(NULL) { EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == @@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase numext::swap(m_dimensions[i], m_dimensions[j]); } } + + // A set of parameters that will allow output kernel to get from output + // tensor dimensions (i, j) into the original tensor dimensions. + // TODO(ezhulenev): Add parameters required to infer output tensor index for + // more complex contractions than 2x2 on internal dimension. + m_tensor_contraction_params = { + /**swapped_arguments=*/static_cast<int>(Layout) == RowMajor}; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation - gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); + const auto output_mapper = output.getSubMapper(i2, j2); + gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc, + Scalar(1), -1, -1, 0, 0); + + // We are done with this [i2, j2] output block. + if (k2 + kc >= k) { + m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2, + actual_mc, actual_nc); + } } } } @@ -848,23 +916,26 @@ protected: Index m_j_size; Index m_k_size; + TensorContractionParams m_tensor_contraction_params; + TensorEvaluator<EvalLeftArgType, Device> m_leftImpl; TensorEvaluator<EvalRightArgType, Device> m_rightImpl; const Device& m_device; + OutputKernelType m_output_kernel; Scalar* m_result; bool m_can_use_xsmm; }; // evaluator for default device -template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> : +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> : public TensorContractionEvaluatorBase< - TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > { + typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; |