diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h | 407 |
1 files changed, 227 insertions, 180 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h index 215a4ebad..d6db45ade 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h @@ -10,8 +10,9 @@ #ifndef EIGEN_CXX11_TENSOR_TENSOR_FFT_H #define EIGEN_CXX11_TENSOR_TENSOR_FFT_H -// NVCC fails to compile this code -#if !defined(__CUDACC__) +// This code requires the ability to initialize arrays of constant +// values directly inside a class. +#if __cplusplus >= 201103L || EIGEN_COMP_MSVC >= 1900 namespace Eigen { @@ -135,6 +136,7 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D BlockAccess = false, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = false, + RawAccess = false }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_fft(op.fft()), m_impl(op.expression(), device), m_data(NULL), m_device(device) { @@ -205,7 +207,7 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D } for (size_t i = 0; i < m_fft.size(); ++i) { - int dim = m_fft[i]; + Index dim = m_fft[i]; eigen_assert(dim >= 0 && dim < NumDims); Index line_len = m_dimensions[dim]; eigen_assert(line_len >= 1); @@ -218,19 +220,39 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite); ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1)); if (!is_power_of_two) { - ComplexScalar pos_j_base = ComplexScalar(std::cos(M_PI/line_len), std::sin(M_PI/line_len)); - for (Index j = 0; j < line_len + 1; ++j) { - pos_j_base_powered[j] = std::pow(pos_j_base, j * j); + // Compute twiddle factors + // t_n = exp(sqrt(-1) * pi * n^2 / line_len) + // for n = 0, 1,..., line_len-1. + // For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2 + pos_j_base_powered[0] = ComplexScalar(1, 0); + if (line_len > 1) { + const RealScalar pi_over_len(EIGEN_PI / line_len); + const ComplexScalar pos_j_base = ComplexScalar( + std::cos(pi_over_len), std::sin(pi_over_len)); + pos_j_base_powered[1] = pos_j_base; + if (line_len > 2) { + const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base; + for (int j = 2; j < line_len + 1; ++j) { + pos_j_base_powered[j] = pos_j_base_powered[j - 1] * + pos_j_base_powered[j - 1] / + pos_j_base_powered[j - 2] * pos_j_base_sq; + } + } } } for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) { - Index base_offset = getBaseOffsetFromIndex(partial_index, dim); + const Index base_offset = getBaseOffsetFromIndex(partial_index, dim); // get data into line_buf - for (Index j = 0; j < line_len; ++j) { - Index offset = getIndexFromOffset(base_offset, dim, j); - line_buf[j] = buf[offset]; + const Index stride = m_strides[dim]; + if (stride == 1) { + memcpy(line_buf, &buf[base_offset], line_len*sizeof(ComplexScalar)); + } else { + Index offset = base_offset; + for (int j = 0; j < line_len; ++j, offset += stride) { + line_buf[j] = buf[offset]; + } } // processs the line @@ -242,14 +264,18 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D } // write back - for (Index j = 0; j < line_len; ++j) { - const ComplexScalar div_factor = (FFTDir == FFT_FORWARD) ? ComplexScalar(1, 0) : ComplexScalar(line_len, 0); - Index offset = getIndexFromOffset(base_offset, dim, j); - buf[offset] = line_buf[j] / div_factor; + if (FFTDir == FFT_FORWARD && stride == 1) { + memcpy(&buf[base_offset], line_buf, line_len*sizeof(ComplexScalar)); + } else { + Index offset = base_offset; + const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0); + for (int j = 0; j < line_len; ++j, offset += stride) { + buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor; + } } } m_device.deallocate(line_buf); - if (!pos_j_base_powered) { + if (!is_power_of_two) { m_device.deallocate(a); m_device.deallocate(b); m_device.deallocate(pos_j_base_powered); @@ -371,109 +397,130 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D } } - template<int Dir> - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(ComplexScalar* data, Index n, Index n_power_of_2) { - eigen_assert(isPowerOfTwo(n)); - if (n == 1) { - return; + template <int Dir> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_2(ComplexScalar* data) { + ComplexScalar tmp = data[1]; + data[1] = data[0] - data[1]; + data[0] += tmp; + } + + template <int Dir> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_4(ComplexScalar* data) { + ComplexScalar tmp[4]; + tmp[0] = data[0] + data[1]; + tmp[1] = data[0] - data[1]; + tmp[2] = data[2] + data[3]; + if (Dir == FFT_FORWARD) { + tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]); + } else { + tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]); } - else if (n == 2) { - ComplexScalar tmp = data[1]; - data[1] = data[0] - data[1]; - data[0] += tmp; - return; + data[0] = tmp[0] + tmp[2]; + data[1] = tmp[1] + tmp[3]; + data[2] = tmp[0] - tmp[2]; + data[3] = tmp[1] - tmp[3]; + } + + template <int Dir> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_8(ComplexScalar* data) { + ComplexScalar tmp_1[8]; + ComplexScalar tmp_2[8]; + + tmp_1[0] = data[0] + data[1]; + tmp_1[1] = data[0] - data[1]; + tmp_1[2] = data[2] + data[3]; + if (Dir == FFT_FORWARD) { + tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1); + } else { + tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1); } - else if (n == 4) { - ComplexScalar tmp[4]; - tmp[0] = data[0] + data[1]; - tmp[1] = data[0] - data[1]; - tmp[2] = data[2] + data[3]; - if(Dir == FFT_FORWARD) { - tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]); - } - else { - tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]); - } - data[0] = tmp[0] + tmp[2]; - data[1] = tmp[1] + tmp[3]; - data[2] = tmp[0] - tmp[2]; - data[3] = tmp[1] - tmp[3]; - return; + tmp_1[4] = data[4] + data[5]; + tmp_1[5] = data[4] - data[5]; + tmp_1[6] = data[6] + data[7]; + if (Dir == FFT_FORWARD) { + tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1); + } else { + tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1); } - else if (n == 8) { - ComplexScalar tmp_1[8]; - ComplexScalar tmp_2[8]; - - tmp_1[0] = data[0] + data[1]; - tmp_1[1] = data[0] - data[1]; - tmp_1[2] = data[2] + data[3]; - if (Dir == FFT_FORWARD) { - tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1); - } - else { - tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1); - } - tmp_1[4] = data[4] + data[5]; - tmp_1[5] = data[4] - data[5]; - tmp_1[6] = data[6] + data[7]; - if (Dir == FFT_FORWARD) { - tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1); - } - else { - tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1); - } - tmp_2[0] = tmp_1[0] + tmp_1[2]; - tmp_2[1] = tmp_1[1] + tmp_1[3]; - tmp_2[2] = tmp_1[0] - tmp_1[2]; - tmp_2[3] = tmp_1[1] - tmp_1[3]; - tmp_2[4] = tmp_1[4] + tmp_1[6]; - // SQRT2DIV2 = sqrt(2)/2 - #define SQRT2DIV2 0.7071067811865476 - if (Dir == FFT_FORWARD) { - tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2); - tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1); - tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2); - } - else { - tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2); - tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1); - tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2); - } - data[0] = tmp_2[0] + tmp_2[4]; - data[1] = tmp_2[1] + tmp_2[5]; - data[2] = tmp_2[2] + tmp_2[6]; - data[3] = tmp_2[3] + tmp_2[7]; - data[4] = tmp_2[0] - tmp_2[4]; - data[5] = tmp_2[1] - tmp_2[5]; - data[6] = tmp_2[2] - tmp_2[6]; - data[7] = tmp_2[3] - tmp_2[7]; - - return; + tmp_2[0] = tmp_1[0] + tmp_1[2]; + tmp_2[1] = tmp_1[1] + tmp_1[3]; + tmp_2[2] = tmp_1[0] - tmp_1[2]; + tmp_2[3] = tmp_1[1] - tmp_1[3]; + tmp_2[4] = tmp_1[4] + tmp_1[6]; +// SQRT2DIV2 = sqrt(2)/2 +#define SQRT2DIV2 0.7071067811865476 + if (Dir == FFT_FORWARD) { + tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2); + tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1); + tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2); + } else { + tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2); + tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1); + tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2); } - else { - compute_1D_Butterfly<Dir>(data, n/2, n_power_of_2 - 1); - compute_1D_Butterfly<Dir>(data + n/2, n/2, n_power_of_2 - 1); - //Original code: - //RealScalar wtemp = std::sin(M_PI/n); - //RealScalar wpi = -std::sin(2 * M_PI/n); - RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2]; - RealScalar wpi; - if (Dir == FFT_FORWARD) { - wpi = m_minus_sin_2_PI_div_n_LUT[n_power_of_2]; - } - else { - wpi = 0 - m_minus_sin_2_PI_div_n_LUT[n_power_of_2]; - } + data[0] = tmp_2[0] + tmp_2[4]; + data[1] = tmp_2[1] + tmp_2[5]; + data[2] = tmp_2[2] + tmp_2[6]; + data[3] = tmp_2[3] + tmp_2[7]; + data[4] = tmp_2[0] - tmp_2[4]; + data[5] = tmp_2[1] - tmp_2[5]; + data[6] = tmp_2[2] - tmp_2[6]; + data[7] = tmp_2[3] - tmp_2[7]; + } - const ComplexScalar wp(wtemp, wpi); - ComplexScalar w(1.0, 0.0); - for(Index i = 0; i < n/2; i++) { - ComplexScalar temp(data[i + n/2] * w); - data[i + n/2] = data[i] - temp; - data[i] += temp; - w += w * wp; - } - return; + template <int Dir> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_1D_merge( + ComplexScalar* data, Index n, Index n_power_of_2) { + // Original code: + // RealScalar wtemp = std::sin(M_PI/n); + // RealScalar wpi = -std::sin(2 * M_PI/n); + const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2]; + const RealScalar wpi = (Dir == FFT_FORWARD) + ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2] + : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2]; + + const ComplexScalar wp(wtemp, wpi); + const ComplexScalar wp_one = wp + ComplexScalar(1, 0); + const ComplexScalar wp_one_2 = wp_one * wp_one; + const ComplexScalar wp_one_3 = wp_one_2 * wp_one; + const ComplexScalar wp_one_4 = wp_one_3 * wp_one; + const Index n2 = n / 2; + ComplexScalar w(1.0, 0.0); + for (Index i = 0; i < n2; i += 4) { + ComplexScalar temp0(data[i + n2] * w); + ComplexScalar temp1(data[i + 1 + n2] * w * wp_one); + ComplexScalar temp2(data[i + 2 + n2] * w * wp_one_2); + ComplexScalar temp3(data[i + 3 + n2] * w * wp_one_3); + w = w * wp_one_4; + + data[i + n2] = data[i] - temp0; + data[i] += temp0; + + data[i + 1 + n2] = data[i + 1] - temp1; + data[i + 1] += temp1; + + data[i + 2 + n2] = data[i + 2] - temp2; + data[i + 2] += temp2; + + data[i + 3 + n2] = data[i + 3] - temp3; + data[i + 3] += temp3; + } + } + + template <int Dir> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly( + ComplexScalar* data, Index n, Index n_power_of_2) { + eigen_assert(isPowerOfTwo(n)); + if (n > 8) { + compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1); + compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1); + butterfly_1D_merge<Dir>(data, n, n_power_of_2); + } else if (n == 8) { + butterfly_8<Dir>(data); + } else if (n == 4) { + butterfly_4<Dir>(data); + } else if (n == 2) { + butterfly_2<Dir>(data); } } @@ -518,81 +565,81 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D // This will support a maximum FFT size of 2^32 for each dimension // m_sin_PI_div_n_LUT[i] = (-2) * std::sin(M_PI / std::pow(2,i)) ^ 2; - RealScalar m_sin_PI_div_n_LUT[32] = { - 0.0, - -2, - -0.999999999999999, - -0.292893218813453, - -0.0761204674887130, - -0.0192147195967696, - -0.00481527332780311, - -0.00120454379482761, - -3.01181303795779e-04, - -7.52981608554592e-05, - -1.88247173988574e-05, - -4.70619042382852e-06, - -1.17654829809007e-06, - -2.94137117780840e-07, - -7.35342821488550e-08, - -1.83835707061916e-08, - -4.59589268710903e-09, - -1.14897317243732e-09, - -2.87243293150586e-10, - -7.18108232902250e-11, - -1.79527058227174e-11, - -4.48817645568941e-12, - -1.12204411392298e-12, - -2.80511028480785e-13, - -7.01277571201985e-14, - -1.75319392800498e-14, - -4.38298482001247e-15, - -1.09574620500312e-15, - -2.73936551250781e-16, - -6.84841378126949e-17, - -1.71210344531737e-17, - -4.28025861329343e-18 + const RealScalar m_sin_PI_div_n_LUT[32] = { + RealScalar(0.0), + RealScalar(-2), + RealScalar(-0.999999999999999), + RealScalar(-0.292893218813453), + RealScalar(-0.0761204674887130), + RealScalar(-0.0192147195967696), + RealScalar(-0.00481527332780311), + RealScalar(-0.00120454379482761), + RealScalar(-3.01181303795779e-04), + RealScalar(-7.52981608554592e-05), + RealScalar(-1.88247173988574e-05), + RealScalar(-4.70619042382852e-06), + RealScalar(-1.17654829809007e-06), + RealScalar(-2.94137117780840e-07), + RealScalar(-7.35342821488550e-08), + RealScalar(-1.83835707061916e-08), + RealScalar(-4.59589268710903e-09), + RealScalar(-1.14897317243732e-09), + RealScalar(-2.87243293150586e-10), + RealScalar( -7.18108232902250e-11), + RealScalar(-1.79527058227174e-11), + RealScalar(-4.48817645568941e-12), + RealScalar(-1.12204411392298e-12), + RealScalar(-2.80511028480785e-13), + RealScalar(-7.01277571201985e-14), + RealScalar(-1.75319392800498e-14), + RealScalar(-4.38298482001247e-15), + RealScalar(-1.09574620500312e-15), + RealScalar(-2.73936551250781e-16), + RealScalar(-6.84841378126949e-17), + RealScalar(-1.71210344531737e-17), + RealScalar(-4.28025861329343e-18) }; // m_minus_sin_2_PI_div_n_LUT[i] = -std::sin(2 * M_PI / std::pow(2,i)); - RealScalar m_minus_sin_2_PI_div_n_LUT[32] = { - 0.0, - 0.0, - -1.00000000000000e+00, - -7.07106781186547e-01, - -3.82683432365090e-01, - -1.95090322016128e-01, - -9.80171403295606e-02, - -4.90676743274180e-02, - -2.45412285229123e-02, - -1.22715382857199e-02, - -6.13588464915448e-03, - -3.06795676296598e-03, - -1.53398018628477e-03, - -7.66990318742704e-04, - -3.83495187571396e-04, - -1.91747597310703e-04, - -9.58737990959773e-05, - -4.79368996030669e-05, - -2.39684498084182e-05, - -1.19842249050697e-05, - -5.99211245264243e-06, - -2.99605622633466e-06, - -1.49802811316901e-06, - -7.49014056584716e-07, - -3.74507028292384e-07, - -1.87253514146195e-07, - -9.36267570730981e-08, - -4.68133785365491e-08, - -2.34066892682746e-08, - -1.17033446341373e-08, - -5.85167231706864e-09, - -2.92583615853432e-09 + const RealScalar m_minus_sin_2_PI_div_n_LUT[32] = { + RealScalar(0.0), + RealScalar(0.0), + RealScalar(-1.00000000000000e+00), + RealScalar(-7.07106781186547e-01), + RealScalar(-3.82683432365090e-01), + RealScalar(-1.95090322016128e-01), + RealScalar(-9.80171403295606e-02), + RealScalar(-4.90676743274180e-02), + RealScalar(-2.45412285229123e-02), + RealScalar(-1.22715382857199e-02), + RealScalar(-6.13588464915448e-03), + RealScalar(-3.06795676296598e-03), + RealScalar(-1.53398018628477e-03), + RealScalar(-7.66990318742704e-04), + RealScalar(-3.83495187571396e-04), + RealScalar(-1.91747597310703e-04), + RealScalar(-9.58737990959773e-05), + RealScalar(-4.79368996030669e-05), + RealScalar(-2.39684498084182e-05), + RealScalar(-1.19842249050697e-05), + RealScalar(-5.99211245264243e-06), + RealScalar(-2.99605622633466e-06), + RealScalar(-1.49802811316901e-06), + RealScalar(-7.49014056584716e-07), + RealScalar(-3.74507028292384e-07), + RealScalar(-1.87253514146195e-07), + RealScalar(-9.36267570730981e-08), + RealScalar(-4.68133785365491e-08), + RealScalar(-2.34066892682746e-08), + RealScalar(-1.17033446341373e-08), + RealScalar(-5.85167231706864e-09), + RealScalar(-2.92583615853432e-09) }; }; } // end namespace Eigen -#endif // __CUDACC__ +#endif // EIGEN_HAS_CONSTEXPR #endif // EIGEN_CXX11_TENSOR_TENSOR_FFT_H |