diff options
author | RJ Ryan <rryan@mixxx.org> | 2017-12-31 10:44:56 -0500 |
---|---|---|
committer | RJ Ryan <rryan@mixxx.org> | 2017-12-31 10:44:56 -0500 |
commit | 59985cfd26416fb6b196af868c187e90d237c352 (patch) | |
tree | 66cb438384d1b79cbdb5dca768ad6e3cf2f1ab14 /unsupported | |
parent | f9bdcea022e24bac4a66a937c37de92f7f22b9da (diff) |
Disable use of recurrence for computing twiddle factors. Fixes FFT precision issues for large FFTs. https://github.com/tensorflow/tensorflow/issues/10749#issuecomment-354557689
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h | 40 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_fft.cpp | 28 |
2 files changed, 54 insertions, 14 deletions
```diff
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h
index 10e0a8a6b..f81da318c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h
@@ -231,20 +231,32 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
       //   t_n = exp(sqrt(-1) * pi * n^2 / line_len)
       // for n = 0, 1,..., line_len-1.
       // For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
-      pos_j_base_powered[0] = ComplexScalar(1, 0);
-      if (line_len > 1) {
-        const RealScalar pi_over_len(EIGEN_PI / line_len);
-        const ComplexScalar pos_j_base = ComplexScalar(
-            std::cos(pi_over_len), std::sin(pi_over_len));
-        pos_j_base_powered[1] = pos_j_base;
-        if (line_len > 2) {
-          const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
-          for (int j = 2; j < line_len + 1; ++j) {
-            pos_j_base_powered[j] = pos_j_base_powered[j - 1] *
-                pos_j_base_powered[j - 1] /
-                pos_j_base_powered[j - 2] * pos_j_base_sq;
-          }
-        }
+
+      // The recurrence is correct in exact arithmetic, but causes
+      // numerical issues for large transforms, especially in
+      // single-precision floating point.
+      //
+      // pos_j_base_powered[0] = ComplexScalar(1, 0);
+      // if (line_len > 1) {
+      //   const ComplexScalar pos_j_base = ComplexScalar(
+      //       numext::cos(M_PI / line_len), numext::sin(M_PI / line_len));
+      //   pos_j_base_powered[1] = pos_j_base;
+      //   if (line_len > 2) {
+      //     const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
+      //     for (int i = 2; i < line_len + 1; ++i) {
+      //       pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
+      //           pos_j_base_powered[i - 1] /
+      //           pos_j_base_powered[i - 2] *
+      //           pos_j_base_sq;
+      //     }
+      //   }
+      // }
+      // TODO(rmlarsen): Find a way to use Eigen's vectorized sin
+      // and cosine functions here.
+      for (int j = 0; j < line_len + 1; ++j) {
+        double arg = ((EIGEN_PI * j) * j) / line_len;
+        std::complex<double> tmp(numext::cos(arg), numext::sin(arg));
+        pos_j_base_powered[j] = static_cast<ComplexScalar>(tmp);
       }
     }
 
diff --git a/unsupported/test/cxx11_tensor_fft.cpp b/unsupported/test/cxx11_tensor_fft.cpp
index 2f14ebc62..a55369477 100644
--- a/unsupported/test/cxx11_tensor_fft.cpp
+++ b/unsupported/test/cxx11_tensor_fft.cpp
@@ -224,6 +224,32 @@ static void test_fft_real_input_energy() {
   }
 }
 
+template <typename RealScalar>
+static void test_fft_non_power_of_2_round_trip(int exponent) {
+  int n = (1 << exponent) + 1;
+
+  Eigen::DSizes<long, 1> dimensions;
+  dimensions[0] = n;
+  const DSizes<long, 1> arr = dimensions;
+  Tensor<RealScalar, 1, ColMajor, long> input;
+
+  input.resize(arr);
+  input.setRandom();
+
+  array<int, 1> fft;
+  fft[0] = 0;
+
+  Tensor<std::complex<RealScalar>, 1, ColMajor> forward =
+      input.template fft<BothParts, FFT_FORWARD>(fft);
+
+  Tensor<RealScalar, 1, ColMajor, long> output =
+      forward.template fft<RealPart, FFT_REVERSE>(fft);
+
+  for (int i = 0; i < n; ++i) {
+    VERIFY_IS_APPROX(input[i], output[i]);
+  }
+}
+
 void test_cxx11_tensor_fft() {
   test_fft_complex_input_golden();
   test_fft_real_input_golden();
@@ -270,4 +296,6 @@ void test_cxx11_tensor_fft() {
   test_fft_real_input_energy<RowMajor, double, true, Eigen::BothParts, FFT_FORWARD, 4>();
   test_fft_real_input_energy<RowMajor, float, false, Eigen::BothParts, FFT_FORWARD, 4>();
   test_fft_real_input_energy<RowMajor, double, false, Eigen::BothParts, FFT_FORWARD, 4>();
+
+  test_fft_non_power_of_2_round_trip<float>(7);
 }
```