// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner
// Copyright (C) 2015 Jianwei Cui
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTIONBYFFT_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTIONBYFFT_H

namespace Eigen {

/** \class TensorConvolutionByFFT
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor convolution class that evaluates the convolution via the
  * fast Fourier transform instead of by direct summation.
  *
  */
namespace internal {

template<typename Indices, typename InputXprType, typename KernelXprType>
struct traits<TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename promote_storage_type<typename InputXprType::Scalar,
                                        typename KernelXprType::Scalar>::ret Scalar;
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename promote_storage_type<typename traits<InputXprType>::StorageKind,
                                        typename traits<KernelXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<InputXprType>::Index,
                                      typename traits<KernelXprType>::Index>::type Index;
  typedef typename InputXprType::Nested LhsNested;
  typedef typename KernelXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<InputXprType>::NumDimensions;
  static const int Layout = traits<InputXprType>::Layout;

  enum {
    Flags = 0,
  };
};

template<typename Indices, typename InputXprType, typename KernelXprType>
struct eval<TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType>, Eigen::Dense>
{
  typedef const TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType>& type;
};

template<typename Indices, typename InputXprType, typename KernelXprType>
struct nested<TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType>, 1,
              typename eval<TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType> >::type>
{
  typedef TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType> type;
};

}  // end namespace internal


template<typename Indices, typename InputXprType, typename KernelXprType>
class TensorConvolutionByFFTOp : public TensorBase<TensorConvolutionByFFTOp<Indices, InputXprType, KernelXprType> >
{
  public:
    typedef typename Eigen::internal::traits<TensorConvolutionByFFTOp>::Scalar Scalar;
    typedef typename Eigen::internal::traits<TensorConvolutionByFFTOp>::Packet Packet;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename internal::promote_storage_type<typename InputXprType::CoeffReturnType,
                                                    typename KernelXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename internal::promote_storage_type<typename InputXprType::PacketReturnType,
                                                    typename KernelXprType::PacketReturnType>::ret PacketReturnType;
    typedef typename Eigen::internal::nested<TensorConvolutionByFFTOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorConvolutionByFFTOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorConvolutionByFFTOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConvolutionByFFTOp(const InputXprType& input,
                                                                   const KernelXprType& kernel,
                                                                   const Indices& dims)
      : m_input_xpr(input), m_kernel_xpr(kernel), m_indices(dims) {}

    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE
    const Indices& indices() const { return m_indices; }

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE
    const typename internal::remove_all<typename InputXprType::Nested>::type&
    inputExpression() const { return m_input_xpr; }

    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE
    const typename internal::remove_all<typename KernelXprType::Nested>::type&
    kernelExpression() const { return m_kernel_xpr; }

  protected:
    typename InputXprType::Nested m_input_xpr;
    typename KernelXprType::Nested m_kernel_xpr;
    const Indices m_indices;
};
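
// A minimal usage sketch (an assumption for illustration, not documented
// Eigen API: this experimental op has no TensorBase entry point analogous to
// TensorBase::convolve(), so the expression is built directly, and the sizes
// below are arbitrary). Assigning the op to a tensor triggers evaluation:
//
//   Eigen::Tensor<float, 2> input(20, 30);
//   Eigen::Tensor<float, 2> kernel(3, 3);
//   input.setRandom();
//   kernel.setRandom();
//   Eigen::array<ptrdiff_t, 2> dims = {{0, 1}};  // convolve along both dimensions
//   typedef Eigen::TensorConvolutionByFFTOp<
//       const Eigen::array<ptrdiff_t, 2>,
//       const Eigen::Tensor<float, 2>,
//       const Eigen::Tensor<float, 2> > ConvOp;
//   // "Valid" convolution: each convolved dimension shrinks by kernel_dim - 1.
//   Eigen::Tensor<float, 2> result(18, 28);
//   result = ConvOp(input, kernel, dims);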

template<typename Indices, typename InputArgType, typename KernelArgType, typename Device>
struct TensorEvaluator<const TensorConvolutionByFFTOp<Indices, InputArgType, KernelArgType>, Device>
{
  typedef TensorConvolutionByFFTOp<Indices, InputArgType, KernelArgType> XprType;

  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;

  static const int NumDims = internal::array_size<typename TensorEvaluator<InputArgType, Device>::Dimensions>::value;
  static const int NumKernelDims = internal::array_size<Indices>::value;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumDims> Dimensions;

  enum {
    IsAligned = TensorEvaluator<InputArgType, Device>::IsAligned &
                TensorEvaluator<KernelArgType, Device>::IsAligned,
    PacketAccess = false,
    BlockAccess = false,
    Layout = TensorEvaluator<InputArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
  };

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device),
        m_kernelArg(op.kernelExpression()), m_kernel(NULL), m_local_kernel(false), m_device(device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<InputArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<KernelArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
    const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions();

    // Strides of the input tensor, in storage order.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStride[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStride[i] = m_inputStride[i - 1] * input_dims[i - 1];
      }
    } else {
      m_inputStride[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_inputStride[i] = m_inputStride[i + 1] * input_dims[i + 1];
      }
    }

    // Output dimensions: each convolved dimension shrinks to
    // input_dim - kernel_dim + 1 ("valid" convolution); the others are unchanged.
    m_dimensions = m_inputImpl.dimensions();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumKernelDims; ++i) {
        const Index index = op.indices()[i];
        const Index input_dim = input_dims[index];
        const Index kernel_dim = kernel_dims[i];
        const Index result_dim = input_dim - kernel_dim + 1;
        m_dimensions[index] = result_dim;

        if (i > 0) {
          m_kernelStride[i] = m_kernelStride[i - 1] * kernel_dims[i - 1];
        } else {
          m_kernelStride[0] = 1;
        }
        m_indexStride[i] = m_inputStride[index];
      }

      m_outputStride[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStride[i] = m_outputStride[i - 1] * m_dimensions[i - 1];
      }
    } else {
      for (int i = NumKernelDims - 1; i >= 0; --i) {
        const Index index = op.indices()[i];
        const Index input_dim = input_dims[index];
        const Index kernel_dim = kernel_dims[i];
        const Index result_dim = input_dim - kernel_dim + 1;
        m_dimensions[index] = result_dim;

        if (i < NumKernelDims - 1) {
          m_kernelStride[i] = m_kernelStride[i + 1] * kernel_dims[i + 1];
        } else {
          m_kernelStride[NumKernelDims - 1] = 1;
        }
        m_indexStride[i] = m_inputStride[index];
      }

      m_outputStride[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_outputStride[i] = m_outputStride[i + 1] * m_dimensions[i + 1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_inputImpl.evalSubExprsIfNeeded(NULL);
    m_kernelImpl.evalSubExprsIfNeeded(NULL);

    typedef typename internal::traits<InputArgType>::Index TensorIndex;

    // Materialize the input and the kernel into dense temporaries. Note that
    // the code below assumes the kernel has the same rank as the input
    // (NumKernelDims == NumDims).
    Tensor<Scalar, NumDims, Layout, TensorIndex> input(m_inputImpl.dimensions());
    for (int i = 0; i < m_inputImpl.dimensions().TotalSize(); ++i) {
      input.data()[i] = m_inputImpl.coeff(i);
    }

    Tensor<Scalar, NumDims, Layout, TensorIndex> kernel(m_kernelImpl.dimensions());
    for (int i = 0; i < m_kernelImpl.dimensions().TotalSize(); ++i) {
      kernel.data()[i] = m_kernelImpl.coeff(i);
    }

    // Zero-pad the kernel up to the size of the input.
    array<std::pair<ptrdiff_t, ptrdiff_t>, NumDims> paddings;
    for (int i = 0; i < NumDims; ++i) {
      paddings[i] = std::make_pair(0, m_inputImpl.dimensions()[i] - m_kernelImpl.dimensions()[i]);
    }

    Eigen::array<bool, NumKernelDims> reverse;
    for (int i = 0; i < NumKernelDims; ++i) {
      reverse[i] = true;
    }

    // Transform along every dimension.
    Eigen::array<int, NumDims> fft;
    for (int i = 0; i < NumDims; ++i) {
      fft[i] = i;
    }

    // The circular convolution wraps around by kernel_dim - 1 entries along
    // each dimension; the slice below cuts them off to keep the valid part.
    Eigen::DSizes<TensorIndex, NumDims> slice_offsets;
    for (int i = 0; i < NumDims; ++i) {
      slice_offsets[i] = m_kernelImpl.dimensions()[i] - 1;
    }

    Eigen::DSizes<TensorIndex, NumDims> slice_extents;
    for (int i = 0; i < NumDims; ++i) {
      slice_extents[i] = m_inputImpl.dimensions()[i] - m_kernelImpl.dimensions()[i] + 1;
    }

    Tensor<Scalar, NumDims, Layout, TensorIndex> kernel_variant = kernel.reverse(reverse).pad(paddings);
    Tensor<std::complex<RealScalar>, NumDims, Layout, TensorIndex> kernel_fft =
        kernel_variant.template fft<BothParts, FFT_FORWARD>(fft);
    //Tensor<std::complex<RealScalar>, NumDims, Layout|IndexType> kernel_fft = kernel.reverse(reverse).pad(paddings).template fft<2>(fft);
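
    // The remaining steps apply the discrete convolution theorem: the valid
    // convolution of the input with the kernel equals the inverse FFT of the
    // elementwise product of (a) the FFT of the input and (b) the FFT of the
    // reversed, zero-padded kernel, with the kernel_dim - 1 wrap-around
    // entries sliced off along each dimension. A 1-D illustration (sizes
    // only): an input of length 8 and a kernel of length 3 yield a circular
    // result of length 8, of which entries [2, 7] form the length-6 valid
    // convolution.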
    Tensor<std::complex<RealScalar>, NumDims, Layout, TensorIndex> input_fft =
        input.template fft<BothParts, FFT_FORWARD>(fft);
    Tensor<std::complex<RealScalar>, NumDims, Layout, TensorIndex> prod =
        (input_fft * kernel_fft).template fft<BothParts, FFT_REVERSE>(fft);
    Tensor<std::complex<RealScalar>, NumDims, Layout, TensorIndex> tensor_result =
        prod.slice(slice_offsets, slice_extents);

    for (int i = 0; i < tensor_result.size(); ++i) {
      data[i] = std::real(tensor_result.data()[i]);
    }
    return false;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_inputImpl.cleanup();
    // Also release the kernel sub-expression.
    m_kernelImpl.cleanup();
    if (m_local_kernel) {
      m_device.deallocate((void*)m_kernel);
      m_local_kernel = false;
    }
    m_kernel = NULL;
  }

  void evalTo(typename XprType::Scalar* buffer) {
    evalSubExprsIfNeeded(NULL);
    for (int i = 0; i < dimensions().TotalSize(); ++i) {
      buffer[i] += coeff(i);
    }
    cleanup();
  }

  // The result is fully computed by evalSubExprsIfNeeded(), which writes it
  // directly into the destination buffer, so per-coefficient access is a stub.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    CoeffReturnType result = CoeffReturnType(0);
    return result;
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 private:
  array<Index, NumDims> m_inputStride;
  array<Index, NumDims> m_outputStride;

  array<Index, NumKernelDims> m_indexStride;
  array<Index, NumKernelDims> m_kernelStride;
  TensorEvaluator<InputArgType, Device> m_inputImpl;
  TensorEvaluator<KernelArgType, Device> m_kernelImpl;
  Dimensions m_dimensions;

  KernelArgType m_kernelArg;
  const Scalar* m_kernel;
  bool m_local_kernel;
  const Device& m_device;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONVOLUTIONBYFFT_H