diff options
author | Benoit Jacob <jacob.benoit.1@gmail.com> | 2009-10-20 23:25:49 -0400 |
---|---|---|
committer | Benoit Jacob <jacob.benoit.1@gmail.com> | 2009-10-20 23:25:49 -0400 |
commit | c3180b7ffbc98d69764b3c1ab17b36e289f7cf7e (patch) | |
tree | a6744a712b9a226ea8436a75b3c0b5985570f65b /unsupported | |
parent | 471b4d509234cbcbd4a1cd45d48fe10efcc2bcf1 (diff) |
MatrixBase:
* support resize() to same size (nop). The case of FFT was another case where that make one's life far easier.
hope that's ok with you Gael. but indeed, i don't use it in the ReturnByValue stuff.
FFT:
* Support MatrixBase (well, in the case with direct memory access such as Map)
* adapt unit test
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/FFT | 32 | ||||
-rw-r--r-- | unsupported/test/FFT.cpp | 133 |
2 files changed, 130 insertions, 35 deletions
diff --git a/unsupported/Eigen/FFT b/unsupported/Eigen/FFT index dc7e85908..fafdb829b 100644 --- a/unsupported/Eigen/FFT +++ b/unsupported/Eigen/FFT @@ -36,7 +36,7 @@ #define DEFAULT_FFT_IMPL ei_fftw_impl #endif -// intel Math Kernel Library: fastest, commerical -- incompatible with Eigen in GPL form +// intel Math Kernel Library: fastest, commercial -- incompatible with Eigen in GPL form #ifdef _MKL_DFTI_H_ // mkl_dfti.h has been included, we can use MKL FFT routines // TODO // #include "src/FFT/ei_imkl_impl.h" @@ -70,6 +70,20 @@ class FFT fwd( &dst[0],&src[0],src.size() ); } + template<typename InputDerived, typename ComplexDerived> + void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src) + { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived) + EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived) + EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time + EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret), + YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) + EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit, + THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) + dst.derived().resize( src.size() ); + fwd( &dst[0],&src[0],src.size() ); + } + template <typename _Output> void inv( _Output * dst, const Complex * src, int nfft) { @@ -83,8 +97,24 @@ class FFT inv( &dst[0],&src[0],src.size() ); } + template<typename OutputDerived, typename ComplexDerived> + void inv( MatrixBase<OutputDerived> & dst, const MatrixBase<ComplexDerived> & src) + { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(OutputDerived) + EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived) + EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,OutputDerived) // size at compile-time + EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret), + YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) + EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit, + THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) + dst.derived().resize( src.size() ); + inv( &dst[0],&src[0],src.size() ); + } + // TODO: multi-dimensional FFTs + // TODO: handle Eigen MatrixBase + // ---> i added fwd and inv specializations above + unit test, is this enough? (bjacob) traits_type & traits() {return m_traits;} private: diff --git a/unsupported/test/FFT.cpp b/unsupported/test/FFT.cpp index f0b9b68bf..cc68f3718 100644 --- a/unsupported/test/FFT.cpp +++ b/unsupported/test/FFT.cpp @@ -39,16 +39,16 @@ complex<long double> promote(double x) { return complex<long double>( x); } complex<long double> promote(long double x) { return complex<long double>( x); } - template <typename T1,typename T2> - long double fft_rmse( const vector<T1> & fftbuf,const vector<T2> & timebuf) + template <typename VectorType1,typename VectorType2> + long double fft_rmse( const VectorType1 & fftbuf,const VectorType2 & timebuf) { long double totalpower=0; long double difpower=0; cerr <<"idx\ttruth\t\tvalue\t|dif|=\n"; - for (size_t k0=0;k0<fftbuf.size();++k0) { + for (size_t k0=0;k0<size_t(fftbuf.size());++k0) { complex<long double> acc = 0; long double phinc = -2.*k0* M_PIl / timebuf.size(); - for (size_t k1=0;k1<timebuf.size();++k1) { + for (size_t k1=0;k1<size_t(timebuf.size());++k1) { acc += promote( timebuf[k1] ) * exp( complex<long double>(0,k1*phinc) ); } totalpower += norm(acc); @@ -61,8 +61,8 @@ complex<long double> promote(long double x) { return complex<long double>( x); return sqrt(difpower/totalpower); } - template <typename T1,typename T2> - long double dif_rmse( const vector<T1> buf1,const vector<T2> buf2) + template <typename VectorType1,typename VectorType2> + long double dif_rmse( const VectorType1& buf1,const VectorType2& buf2) { long double totalpower=0; long double difpower=0; @@ -74,35 +74,59 @@ complex<long double> promote(long double x) { return complex<long double>( x); return sqrt(difpower/totalpower); } -template <class T> -void test_scalar(int nfft) +enum { StdVectorContainer, EigenVectorContainer }; + +template<int Container, typename Scalar> struct VectorType; + +template<typename Scalar> struct VectorType<StdVectorContainer,Scalar> +{ + typedef vector<Scalar> type; +}; + +template<typename Scalar> struct VectorType<EigenVectorContainer,Scalar> { - typedef typename Eigen::FFT<T>::Complex Complex; - typedef typename Eigen::FFT<T>::Scalar Scalar; + typedef Matrix<Scalar,Dynamic,1> type; +}; + +template <int Container, typename T> +void test_scalar_generic(int nfft) +{ + typedef typename FFT<T>::Complex Complex; + typedef typename FFT<T>::Scalar Scalar; + typedef typename VectorType<Container,Scalar>::type ScalarVector; + typedef typename VectorType<Container,Complex>::type ComplexVector; FFT<T> fft; - vector<Scalar> inbuf(nfft); - vector<Complex> outbuf; + ScalarVector inbuf(nfft); + ComplexVector outbuf; for (int k=0;k<nfft;++k) inbuf[k]= (T)(rand()/(double)RAND_MAX - .5); fft.fwd( outbuf,inbuf); VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check - vector<Scalar> buf3; + ScalarVector buf3; fft.inv( buf3 , outbuf); VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check } -template <class T> -void test_complex(int nfft) +template <typename T> +void test_scalar(int nfft) { - typedef typename Eigen::FFT<T>::Complex Complex; + test_scalar_generic<StdVectorContainer,T>(nfft); + test_scalar_generic<EigenVectorContainer,T>(nfft); +} + +template <int Container, typename T> +void test_complex_generic(int nfft) +{ + typedef typename FFT<T>::Complex Complex; + typedef typename VectorType<Container,Complex>::type ComplexVector; FFT<T> fft; - vector<Complex> inbuf(nfft); - vector<Complex> outbuf; - vector<Complex> buf3; + ComplexVector inbuf(nfft); + ComplexVector outbuf; + ComplexVector buf3; for (int k=0;k<nfft;++k) inbuf[k]= Complex( (T)(rand()/(double)RAND_MAX - .5), (T)(rand()/(double)RAND_MAX - .5) ); fft.fwd( outbuf , inbuf); @@ -114,22 +138,63 @@ void test_complex(int nfft) VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check } -void test_FFT() +template <typename T> +void test_complex(int nfft) { + test_complex_generic<StdVectorContainer,T>(nfft); + test_complex_generic<EigenVectorContainer,T>(nfft); +} - CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) ); - CALL_SUBTEST( test_complex<float>(256) ); CALL_SUBTEST( test_complex<double>(256) ); CALL_SUBTEST( test_complex<long double>(256) ); - CALL_SUBTEST( test_complex<float>(3*8) ); CALL_SUBTEST( test_complex<double>(3*8) ); CALL_SUBTEST( test_complex<long double>(3*8) ); - CALL_SUBTEST( test_complex<float>(5*32) ); CALL_SUBTEST( test_complex<double>(5*32) ); CALL_SUBTEST( test_complex<long double>(5*32) ); - CALL_SUBTEST( test_complex<float>(2*3*4) ); CALL_SUBTEST( test_complex<double>(2*3*4) ); CALL_SUBTEST( test_complex<long double>(2*3*4) ); - CALL_SUBTEST( test_complex<float>(2*3*4*5) ); CALL_SUBTEST( test_complex<double>(2*3*4*5) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5) ); - CALL_SUBTEST( test_complex<float>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<double>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) ); - - +void test_FFT() +{ - CALL_SUBTEST( test_scalar<float>(32) ); CALL_SUBTEST( test_scalar<double>(32) ); CALL_SUBTEST( test_scalar<long double>(32) ); - CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) ); CALL_SUBTEST( test_scalar<long double>(45) ); - CALL_SUBTEST( test_scalar<float>(50) ); CALL_SUBTEST( test_scalar<double>(50) ); CALL_SUBTEST( test_scalar<long double>(50) ); - CALL_SUBTEST( test_scalar<float>(256) ); CALL_SUBTEST( test_scalar<double>(256) ); CALL_SUBTEST( test_scalar<long double>(256) ); - CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) ); + CALL_SUBTEST( test_complex<float>(32) ); + CALL_SUBTEST( test_complex<double>(32) ); + CALL_SUBTEST( test_complex<long double>(32) ); + + CALL_SUBTEST( test_complex<float>(256) ); + CALL_SUBTEST( test_complex<double>(256) ); + CALL_SUBTEST( test_complex<long double>(256) ); + + CALL_SUBTEST( test_complex<float>(3*8) ); + CALL_SUBTEST( test_complex<double>(3*8) ); + CALL_SUBTEST( test_complex<long double>(3*8) ); + + CALL_SUBTEST( test_complex<float>(5*32) ); + CALL_SUBTEST( test_complex<double>(5*32) ); + CALL_SUBTEST( test_complex<long double>(5*32) ); + + CALL_SUBTEST( test_complex<float>(2*3*4) ); + CALL_SUBTEST( test_complex<double>(2*3*4) ); + CALL_SUBTEST( test_complex<long double>(2*3*4) ); + + CALL_SUBTEST( test_complex<float>(2*3*4*5) ); + CALL_SUBTEST( test_complex<double>(2*3*4*5) ); + CALL_SUBTEST( test_complex<long double>(2*3*4*5) ); + + CALL_SUBTEST( test_complex<float>(2*3*4*5*7) ); + CALL_SUBTEST( test_complex<double>(2*3*4*5*7) ); + CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) ); + + + + CALL_SUBTEST( test_scalar<float>(32) ); + CALL_SUBTEST( test_scalar<double>(32) ); + CALL_SUBTEST( test_scalar<long double>(32) ); + + CALL_SUBTEST( test_scalar<float>(45) ); + CALL_SUBTEST( test_scalar<double>(45) ); + CALL_SUBTEST( test_scalar<long double>(45) ); + + CALL_SUBTEST( test_scalar<float>(50) ); + CALL_SUBTEST( test_scalar<double>(50) ); + CALL_SUBTEST( test_scalar<long double>(50) ); + + CALL_SUBTEST( test_scalar<float>(256) ); + CALL_SUBTEST( test_scalar<double>(256) ); + CALL_SUBTEST( test_scalar<long double>(256) ); + + CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); + CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) ); + CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) ); } |