From 01fd4096d395e7b816459f571bf2328c8435cc37 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 10 Jul 2018 13:16:38 -0700 Subject: Fuse computations into the Tensor contractions using output kernel --- .../Eigen/CXX11/src/Tensor/TensorContraction.h | 113 +++++++++++++++++---- 1 file changed, 92 insertions(+), 21 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 979fcf4d9..85126a127 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -85,8 +85,8 @@ template #endif -template -struct traits > +template +struct traits > { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename gebp_traits::type, @@ -112,23 +112,24 @@ struct traits > }; }; -template -struct eval, Eigen::Dense> +template +struct eval, Eigen::Dense> { - typedef const TensorContractionOp& type; + typedef const TensorContractionOp& type; }; -template -struct nested, 1, typename eval >::type> +template +struct nested, 1, typename eval >::type> { - typedef TensorContractionOp type; + typedef TensorContractionOp type; }; -template -struct traits, Device_> > { +template +struct traits, Device_> > { typedef Indices_ Indices; typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; + typedef OutputKernelType_ OutputKernelType; typedef Device_ Device; // From NumDims below. @@ -137,8 +138,52 @@ struct traits -class TensorContractionOp : public TensorBase, ReadOnlyAccessors> +// Tensor contraction params that should enable to get from output matrix +// 2-dimensional coordinates to the output tensor dimensions. +struct TensorContractionParams { + // TensorContraction evaluator assumes that both tensors are in ColMajor + // layout, if tensors are in RowMajor evaluator swap lhs with rhs. + bool swapped_arguments; +}; + +// Output kernel allows to fuse operations into the tensor contraction. +// +// Examples: +// 1. Elementwise Relu transformation following Conv2D. +// 2. AddBias to the Conv2D output channels dimension. +// +// See expected implementation in NoOpOutputKernel. +struct OutputKernel { + template + using OutputMapper = internal::blas_data_mapper; +}; + +// Output kernel that does absolutely nothing. +struct NoOpOutputKernel { + /** + * Tensor contraction evaluator calls this kernel after finishing each block + * of output matrix. Output blocks belong to the 2-dimensional output tensor. + * + * TensorContractionParams contains contraction dimensions information + * required to map output 2-d space into the expected output tensor space + * (potentially higher dimensional). + * + * \param[in] output_mapper Access to output tensor memory + * \param[in] params Tensor contraction parameters + * \param[in] i Index of a first row available through output_mapper + * \param[in] j Index of a first column available through output_mapper + * \param[in] num_rows Number of available rows + * \param[in] num_cols Number of available columns + */ + template + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper& output_mapper, + const TensorContractionParams& params, Index i, Index j, Index num_rows, + Index num_cols) const {} +}; + +template +class TensorContractionOp : public TensorBase, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits::Scalar Scalar; @@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( - const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims, + const OutputKernelType& output_kernel = OutputKernelType()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims), + m_output_kernel(output_kernel) {} EIGEN_DEVICE_FUNC const Indices& indices() const { return m_indices; } @@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase::type& rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const OutputKernelType& outputKernel() const { return m_output_kernel; } + protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const Indices m_indices; + const OutputKernelType m_output_kernel; }; @@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase typedef typename internal::traits::Indices Indices; typedef typename internal::traits::LeftArgType LeftArgType; typedef typename internal::traits::RightArgType RightArgType; + typedef typename internal::traits::OutputKernelType OutputKernelType; typedef typename internal::traits::Device Device; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase op.lhsExpression(), op.rhsExpression()), device), m_rightImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), op.rhsExpression(), op.lhsExpression()), device), + m_output_kernel(op.outputKernel()), m_device(device), m_result(NULL) { EIGEN_STATIC_ASSERT((static_cast(TensorEvaluator::Layout) == @@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase numext::swap(m_dimensions[i], m_dimensions[j]); } } + + // A set of parameters that will allow output kernel to get from output + // tensor dimensions (i, j) into the original tensor dimensions. + // TODO(ezhulenev): Add parameters required to infer output tensor index for + // more complex contractions than 2x2 on internal dimension. + m_tensor_contraction_params = { + /**swapped_arguments=*/static_cast(Layout) == RowMajor}; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation - gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); + const auto output_mapper = output.getSubMapper(i2, j2); + gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc, + Scalar(1), -1, -1, 0, 0); + + // We are done with this [i2, j2] output block. + if (k2 + kc >= k) { + m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2, + actual_mc, actual_nc); + } } } } @@ -848,23 +916,26 @@ protected: Index m_j_size; Index m_k_size; + TensorContractionParams m_tensor_contraction_params; + TensorEvaluator m_leftImpl; TensorEvaluator m_rightImpl; const Device& m_device; + OutputKernelType m_output_kernel; Scalar* m_result; bool m_can_use_xsmm; }; // evaluator for default device -template -struct TensorEvaluator, Device> : +template +struct TensorEvaluator, Device> : public TensorContractionEvaluatorBase< - TensorEvaluator, Device> > { - typedef TensorEvaluator, Device> Self; + TensorEvaluator, Device> > { + typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; -- cgit v1.2.3