diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
commit | 01fd4096d395e7b816459f571bf2328c8435cc37 (patch) | |
tree | 02b928b34f77c3e63126c3175b6ea06174818f51 | |
parent | 5539587b1f5b5922b2419b0a4468cf2f393def51 (diff) |
Fuse computations into the Tensor contractions using output kernel
6 files changed, 248 insertions, 37 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index bdc1a17a7..97f90f638 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -517,9 +517,15 @@ class TensorBase<Derived, ReadOnlyAccessors> typedef Eigen::IndexPair<Index> DimensionPair; template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorContractionOp<const Dimensions, const Derived, const OtherDerived> + const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel> contract(const OtherDerived& other, const Dimensions& dims) const { - return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims); + return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims); + } + + template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel> + contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const { + return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel); } // Convolutions. 
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 979fcf4d9..85126a127 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar> #endif -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> > { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type, @@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > }; }; -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense> +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense> { - typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type; + typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type; }; -template<typename Dimensions, typename LhsXprType, typename RhsXprType> -struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type> +template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, 
LhsXprType, RhsXprType, OutputKernelType> >::type> { - typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type; + typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type; }; -template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_> -struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > { +template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_> +struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > { typedef Indices_ Indices; typedef LeftArgType_ LeftArgType; typedef RightArgType_ RightArgType; + typedef OutputKernelType_ OutputKernelType; typedef Device_ Device; // From NumDims below. @@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, } // end namespace internal -template<typename Indices, typename LhsXprType, typename RhsXprType> -class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors> +// Tensor contraction params that should enable to get from output matrix +// 2-dimensional coordinates to the output tensor dimensions. +struct TensorContractionParams { + // TensorContraction evaluator assumes that both tensors are in ColMajor + // layout, if tensors are in RowMajor evaluator swap lhs with rhs. + bool swapped_arguments; +}; + +// Output kernel allows to fuse operations into the tensor contraction. +// +// Examples: +// 1. Elementwise Relu transformation following Conv2D. +// 2. AddBias to the Conv2D output channels dimension. +// +// See expected implementation in NoOpOutputKernel. +struct OutputKernel { + template <typename Index, typename Scalar> + using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>; +}; + +// Output kernel that does absolutely nothing. 
+struct NoOpOutputKernel { + /** + * Tensor contraction evaluator calls this kernel after finishing each block + * of output matrix. Output blocks belong to the 2-dimensional output tensor. + * + * TensorContractionParams contains contraction dimensions information + * required to map output 2-d space into the expected output tensor space + * (potentially higher dimensional). + * + * \param[in] output_mapper Access to output tensor memory + * \param[in] params Tensor contraction parameters + * \param[in] i Index of a first row available through output_mapper + * \param[in] j Index of a first column available through output_mapper + * \param[in] num_rows Number of available rows + * \param[in] num_cols Number of available columns + */ + template <typename Index, typename Scalar> + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/, + const TensorContractionParams& /*params*/, Index /*i*/, Index /*j*/, + Index /*num_rows*/, Index /*num_cols*/) const {} +}; + +template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType> +class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; @@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( - const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims, + const OutputKernelType& output_kernel = OutputKernelType()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims), + m_output_kernel(output_kernel) {} EIGEN_DEVICE_FUNC const Indices& indices() const { return m_indices; } @@ -164,10 +211,14 @@ 
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp const typename internal::remove_all<typename RhsXprType::Nested>::type& rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const OutputKernelType& outputKernel() const { return m_output_kernel; } + protected: typename LhsXprType::Nested m_lhs_xpr; typename RhsXprType::Nested m_rhs_xpr; const Indices m_indices; + const OutputKernelType m_output_kernel; }; @@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase typedef typename internal::traits<Derived>::Indices Indices; typedef typename internal::traits<Derived>::LeftArgType LeftArgType; typedef typename internal::traits<Derived>::RightArgType RightArgType; + typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType; typedef typename internal::traits<Derived>::Device Device; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase op.lhsExpression(), op.rhsExpression()), device), m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), op.rhsExpression(), op.lhsExpression()), device), m_device(device), + m_output_kernel(op.outputKernel()), m_result(NULL) { EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == @@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase numext::swap(m_dimensions[i], m_dimensions[j]); } } + + // A set of parameters that will allow output kernel to get from output + // tensor dimensions (i, j) into the original tensor dimensions. + // TODO(ezhulenev): Add parameters required to infer output tensor index for + // more complex contractions than 2x2 on internal dimension. 
+ m_tensor_contraction_params = { + /**swapped_arguments=*/static_cast<int>(Layout) == RowMajor}; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation - gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); + const auto output_mapper = output.getSubMapper(i2, j2); + gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc, + Scalar(1), -1, -1, 0, 0); + + // We are done with this [i2, j2] output block. + if (k2 + kc >= k) { + m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2, + actual_mc, actual_nc); + } } } } @@ -848,23 +916,26 @@ protected: Index m_j_size; Index m_k_size; + TensorContractionParams m_tensor_contraction_params; + TensorEvaluator<EvalLeftArgType, Device> m_leftImpl; TensorEvaluator<EvalRightArgType, Device> m_rightImpl; const Device& m_device; + OutputKernelType m_output_kernel; Scalar* m_result; bool m_can_use_xsmm; }; // evaluator for default device -template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> : +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> : public TensorContractionEvaluatorBase< - TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > { + typedef TensorEvaluator<const TensorContractionOp<Indices, 
LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 3c007b183..d7536bd6a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -56,16 +56,16 @@ struct packRhsAndKernelArg { } // end namespace internal #endif // EIGEN_USE_SIMPLE_THREAD_POOL -template<typename Indices, typename LeftArgType, typename RightArgType> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> : - public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > { +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> : + public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > { typedef ThreadPoolDevice Device; - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef 
TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT this->m_k_strides); Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper, - OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n, + OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack) .run(); @@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typename LhsMapper, typename RhsMapper, typename OutputMapper> class Context { public: - Context(const Device& device, int num_threads, LhsMapper& lhs, + Context(const Self* self, int num_threads, LhsMapper& lhs, RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) - : device_(device), + : device_(self->m_device), lhs_(lhs), rhs_(rhs), buffer_(buffer), output_(buffer, tm), + output_kernel_(self->m_output_kernel), + tensor_contraction_params_(self->m_tensor_contraction_params), num_threads_(num_threads), shard_by_col_(shard_by_col), parallel_pack_(parallel_pack), @@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT RhsMapper& rhs_; Scalar* const buffer_; OutputMapper output_; + OutputKernelType output_kernel_; + TensorContractionParams tensor_contraction_params_; const int num_threads_; const bool shard_by_col_; const bool parallel_pack_; @@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const Index mend = m * gm_ + gm(m); if (shard_by_col_) { for (Index n1 = n * gn_; n1 < 
nend; n1++) { - for (Index m1 = m * gm_; m1 < mend; m1++) - GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), - packed_lhs_[k % (P - 1)][m1], + for (Index m1 = m * gm_; m1 < mend; m1++) { + const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); + GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1), -1, -1, 0, 0); + + // We are done with the last task for the [m1, n1] block. + if (k + 1 == nk_) { + output_kernel_(output_mapper, tensor_contraction_params_, + m1 * bm_, n1 * bn_, bm(m1), bn(n1)); + } + } } } else { for (Index m1 = m * gm_; m1 < mend; m1++) for (Index n1 = n * gn_; n1 < nend; n1++) { - GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), - packed_lhs_[k % (P - 1)][m1], + const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); + GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1), -1, -1, 0, 0); + + // We are done with the last task for the [m1, n1] block. + if (k + 1 == nk_) { + output_kernel_(output_mapper, tensor_contraction_params_, + m1 * bm_, n1 * bn_, bm(m1), bn(n1)); + } } } signal_kernel(m, n, k + 1, false); @@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } #else // EIGEN_USE_SIMPLE_THREAD_POOL + // TODO(ezhulenev): SimpleThreadPool will be removed in the future, and seems + // like it's not worth adding output kernel support here. 
+ static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "SimpleThreadPool does not support contraction output kernels."); template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> void evalProduct(Scalar* buffer) const { @@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) + // TODO(ezhulenev): Add support for output kernels and LIBXSMM. + static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "XSMM does not support contraction output kernels."); + template<int Alignment> class ContextXsmm { public: diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 6c237bac3..19e456e19 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -65,7 +65,7 @@ template<typename Op, typename Dims, typename XprType, template <class> class Ma template<typename XprType> class TensorIndexTupleOp; template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp; template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp; -template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp; +template<typename Dimensions, typename LeftXprType, typename RightXprType, typename OutputKernelType> class TensorContractionOp; template<typename TargetType, typename XprType> class TensorConversionOp; template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp; template<typename FFT, typename XprType, int FFTDataType, int FFTDirection> class TensorFFTOp; @@ -97,6 +97,8 @@ template<typename XprType> class TensorForcedEvalOp; template<typename ExpressionType, 
typename DeviceType> class TensorDevice; template<typename Derived, typename Device> struct TensorEvaluator; +struct NoOpOutputKernel; + struct DefaultDevice; struct ThreadPoolDevice; struct GpuDevice; diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index ace97057f..918c96277 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -510,6 +510,55 @@ static void test_const_inputs() VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1)); } +// Apply Sqrt to all output elements. +struct SqrtOutputKernel { + template <typename Index, typename Scalar> + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper<Index, Scalar>& output_mapper, + const TensorContractionParams&, Index, Index, Index num_rows, + Index num_cols) const { + for (Index j = 0; j < num_cols; ++j) { + for (Index i = 0; i < num_rows; ++i) { + output_mapper(i, j) = std::sqrt(output_mapper(i, j)); + } + } + } +}; + +template <int DataLayout> +static void test_large_contraction_with_output_kernel() { + Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31); + Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10); + Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10); + + t_left.setRandom(); + t_right.setRandom(); + // Put trash in t_result to verify contraction clears output memory. + t_result.setRandom(); + + // Add a little offset so that the results won't be close to zero. 
+ t_left += t_left.constant(1.0f); + t_right += t_right.constant(1.0f); + + typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf; + MapXf m_left(t_left.data(), 1500, 248); + MapXf m_right(t_right.data(), 248, 1400); + Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400); + + // this contraction should be equivalent to a single matrix multiplication + Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}}); + + // compute results by separate methods + t_result = t_left.contract(t_right, dims, SqrtOutputKernel()); + + m_result = m_left * m_right; + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY(&t_result.data()[i] != &m_result.data()[i]); + VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i])); + } +} + void test_cxx11_tensor_contraction() { CALL_SUBTEST(test_evals<ColMajor>()); @@ -542,4 +591,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_tensor_product<RowMajor>()); CALL_SUBTEST(test_const_inputs<ColMajor>()); CALL_SUBTEST(test_const_inputs<RowMajor>()); + CALL_SUBTEST(test_large_contraction_with_output_kernel<ColMajor>()); + CALL_SUBTEST(test_large_contraction_with_output_kernel<RowMajor>()); } diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 2ef665f30..ea9d8afdc 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -232,6 +232,60 @@ void test_multithread_contraction_agrees_with_singlethread() { } } +// Apply Sqrt to all output elements. 
+struct SqrtOutputKernel { + template <typename Index, typename Scalar> + EIGEN_ALWAYS_INLINE void operator()( + const OutputKernel::OutputMapper<Index, Scalar>& output_mapper, + const TensorContractionParams&, Index, Index, Index num_rows, + Index num_cols) const { + for (Index j = 0; j < num_cols; ++j) { + for (Index i = 0; i < num_rows; ++i) { + output_mapper(i, j) = std::sqrt(output_mapper(i, j)); + } + } + } +}; + +template <int DataLayout> +static void test_multithread_contraction_with_output_kernel() { + typedef Tensor<float, 1>::DimensionPair DimPair; + + const int num_threads = internal::random<int>(2, 11); + ThreadPool threads(num_threads); + Eigen::ThreadPoolDevice device(&threads, num_threads); + + Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31); + Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10); + Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10); + + t_left.setRandom(); + t_right.setRandom(); + // Put trash in t_result to verify contraction clears output memory. + t_result.setRandom(); + + // Add a little offset so that the results won't be close to zero. 
+ t_left += t_left.constant(1.0f); + t_right += t_right.constant(1.0f); + + typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf; + MapXf m_left(t_left.data(), 1500, 248); + MapXf m_right(t_right.data(), 248, 1400); + Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400); + + // this contraction should be equivalent to a single matrix multiplication + Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}}); + + // compute results by separate methods + t_result.device(device) = t_left.contract(t_right, dims, SqrtOutputKernel()); + + m_result = m_left * m_right; + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY(&t_result.data()[i] != &m_result.data()[i]); + VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i])); + } +} template<int DataLayout> void test_full_contraction() { @@ -355,6 +409,8 @@ void test_cxx11_tensor_thread_pool() CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>()); + CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>()); + CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>()); // Exercise various cases that have been problematic in the past. CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>()); |