aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Benoit Jacob <jacob.benoit.1@gmail.com>2009-10-20 23:25:49 -0400
committerGravatar Benoit Jacob <jacob.benoit.1@gmail.com>2009-10-20 23:25:49 -0400
commitc3180b7ffbc98d69764b3c1ab17b36e289f7cf7e (patch)
treea6744a712b9a226ea8436a75b3c0b5985570f65b /unsupported
parent471b4d509234cbcbd4a1cd45d48fe10efcc2bcf1 (diff)
MatrixBase:
* support resize() to same size (nop). The case of FFT was another case where that make one's life far easier. hope that's ok with you Gael. but indeed, i don't use it in the ReturnByValue stuff. FFT: * Support MatrixBase (well, in the case with direct memory access such as Map) * adapt unit test
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/FFT32
-rw-r--r--unsupported/test/FFT.cpp133
2 files changed, 130 insertions, 35 deletions
diff --git a/unsupported/Eigen/FFT b/unsupported/Eigen/FFT
index dc7e85908..fafdb829b 100644
--- a/unsupported/Eigen/FFT
+++ b/unsupported/Eigen/FFT
@@ -36,7 +36,7 @@
#define DEFAULT_FFT_IMPL ei_fftw_impl
#endif
-// intel Math Kernel Library: fastest, commerical -- incompatible with Eigen in GPL form
+// intel Math Kernel Library: fastest, commercial -- incompatible with Eigen in GPL form
#ifdef _MKL_DFTI_H_ // mkl_dfti.h has been included, we can use MKL FFT routines
// TODO
// #include "src/FFT/ei_imkl_impl.h"
@@ -70,6 +70,20 @@ class FFT
fwd( &dst[0],&src[0],src.size() );
}
+ template<typename InputDerived, typename ComplexDerived>
+ void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src)
+ {
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived)
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
+ EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time
+ EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret),
+ YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
+ EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
+ THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
+ dst.derived().resize( src.size() );
+ fwd( &dst[0],&src[0],src.size() );
+ }
+
template <typename _Output>
void inv( _Output * dst, const Complex * src, int nfft)
{
@@ -83,8 +97,24 @@ class FFT
inv( &dst[0],&src[0],src.size() );
}
+ template<typename OutputDerived, typename ComplexDerived>
+ void inv( MatrixBase<OutputDerived> & dst, const MatrixBase<ComplexDerived> & src)
+ {
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(OutputDerived)
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
+ EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,OutputDerived) // size at compile-time
+ EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret),
+ YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
+ EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
+ THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
+ dst.derived().resize( src.size() );
+ inv( &dst[0],&src[0],src.size() );
+ }
+
// TODO: multi-dimensional FFTs
+
// TODO: handle Eigen MatrixBase
+ // ---> i added fwd and inv specializations above + unit test, is this enough? (bjacob)
traits_type & traits() {return m_traits;}
private:
diff --git a/unsupported/test/FFT.cpp b/unsupported/test/FFT.cpp
index f0b9b68bf..cc68f3718 100644
--- a/unsupported/test/FFT.cpp
+++ b/unsupported/test/FFT.cpp
@@ -39,16 +39,16 @@ complex<long double> promote(double x) { return complex<long double>( x); }
complex<long double> promote(long double x) { return complex<long double>( x); }
- template <typename T1,typename T2>
- long double fft_rmse( const vector<T1> & fftbuf,const vector<T2> & timebuf)
+ template <typename VectorType1,typename VectorType2>
+ long double fft_rmse( const VectorType1 & fftbuf,const VectorType2 & timebuf)
{
long double totalpower=0;
long double difpower=0;
cerr <<"idx\ttruth\t\tvalue\t|dif|=\n";
- for (size_t k0=0;k0<fftbuf.size();++k0) {
+ for (size_t k0=0;k0<size_t(fftbuf.size());++k0) {
complex<long double> acc = 0;
long double phinc = -2.*k0* M_PIl / timebuf.size();
- for (size_t k1=0;k1<timebuf.size();++k1) {
+ for (size_t k1=0;k1<size_t(timebuf.size());++k1) {
acc += promote( timebuf[k1] ) * exp( complex<long double>(0,k1*phinc) );
}
totalpower += norm(acc);
@@ -61,8 +61,8 @@ complex<long double> promote(long double x) { return complex<long double>( x);
return sqrt(difpower/totalpower);
}
- template <typename T1,typename T2>
- long double dif_rmse( const vector<T1> buf1,const vector<T2> buf2)
+ template <typename VectorType1,typename VectorType2>
+ long double dif_rmse( const VectorType1& buf1,const VectorType2& buf2)
{
long double totalpower=0;
long double difpower=0;
@@ -74,35 +74,59 @@ complex<long double> promote(long double x) { return complex<long double>( x);
return sqrt(difpower/totalpower);
}
-template <class T>
-void test_scalar(int nfft)
+enum { StdVectorContainer, EigenVectorContainer };
+
+template<int Container, typename Scalar> struct VectorType;
+
+template<typename Scalar> struct VectorType<StdVectorContainer,Scalar>
+{
+ typedef vector<Scalar> type;
+};
+
+template<typename Scalar> struct VectorType<EigenVectorContainer,Scalar>
{
- typedef typename Eigen::FFT<T>::Complex Complex;
- typedef typename Eigen::FFT<T>::Scalar Scalar;
+ typedef Matrix<Scalar,Dynamic,1> type;
+};
+
+template <int Container, typename T>
+void test_scalar_generic(int nfft)
+{
+ typedef typename FFT<T>::Complex Complex;
+ typedef typename FFT<T>::Scalar Scalar;
+ typedef typename VectorType<Container,Scalar>::type ScalarVector;
+ typedef typename VectorType<Container,Complex>::type ComplexVector;
FFT<T> fft;
- vector<Scalar> inbuf(nfft);
- vector<Complex> outbuf;
+ ScalarVector inbuf(nfft);
+ ComplexVector outbuf;
for (int k=0;k<nfft;++k)
inbuf[k]= (T)(rand()/(double)RAND_MAX - .5);
fft.fwd( outbuf,inbuf);
VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check
- vector<Scalar> buf3;
+ ScalarVector buf3;
fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
}
-template <class T>
-void test_complex(int nfft)
+template <typename T>
+void test_scalar(int nfft)
{
- typedef typename Eigen::FFT<T>::Complex Complex;
+ test_scalar_generic<StdVectorContainer,T>(nfft);
+ test_scalar_generic<EigenVectorContainer,T>(nfft);
+}
+
+template <int Container, typename T>
+void test_complex_generic(int nfft)
+{
+ typedef typename FFT<T>::Complex Complex;
+ typedef typename VectorType<Container,Complex>::type ComplexVector;
FFT<T> fft;
- vector<Complex> inbuf(nfft);
- vector<Complex> outbuf;
- vector<Complex> buf3;
+ ComplexVector inbuf(nfft);
+ ComplexVector outbuf;
+ ComplexVector buf3;
for (int k=0;k<nfft;++k)
inbuf[k]= Complex( (T)(rand()/(double)RAND_MAX - .5), (T)(rand()/(double)RAND_MAX - .5) );
fft.fwd( outbuf , inbuf);
@@ -114,22 +138,63 @@ void test_complex(int nfft)
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
}
-void test_FFT()
+template <typename T>
+void test_complex(int nfft)
{
+ test_complex_generic<StdVectorContainer,T>(nfft);
+ test_complex_generic<EigenVectorContainer,T>(nfft);
+}
- CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) );
- CALL_SUBTEST( test_complex<float>(256) ); CALL_SUBTEST( test_complex<double>(256) ); CALL_SUBTEST( test_complex<long double>(256) );
- CALL_SUBTEST( test_complex<float>(3*8) ); CALL_SUBTEST( test_complex<double>(3*8) ); CALL_SUBTEST( test_complex<long double>(3*8) );
- CALL_SUBTEST( test_complex<float>(5*32) ); CALL_SUBTEST( test_complex<double>(5*32) ); CALL_SUBTEST( test_complex<long double>(5*32) );
- CALL_SUBTEST( test_complex<float>(2*3*4) ); CALL_SUBTEST( test_complex<double>(2*3*4) ); CALL_SUBTEST( test_complex<long double>(2*3*4) );
- CALL_SUBTEST( test_complex<float>(2*3*4*5) ); CALL_SUBTEST( test_complex<double>(2*3*4*5) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5) );
- CALL_SUBTEST( test_complex<float>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<double>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) );
-
-
+void test_FFT()
+{
- CALL_SUBTEST( test_scalar<float>(32) ); CALL_SUBTEST( test_scalar<double>(32) ); CALL_SUBTEST( test_scalar<long double>(32) );
- CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) ); CALL_SUBTEST( test_scalar<long double>(45) );
- CALL_SUBTEST( test_scalar<float>(50) ); CALL_SUBTEST( test_scalar<double>(50) ); CALL_SUBTEST( test_scalar<long double>(50) );
- CALL_SUBTEST( test_scalar<float>(256) ); CALL_SUBTEST( test_scalar<double>(256) ); CALL_SUBTEST( test_scalar<long double>(256) );
- CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) );
+ CALL_SUBTEST( test_complex<float>(32) );
+ CALL_SUBTEST( test_complex<double>(32) );
+ CALL_SUBTEST( test_complex<long double>(32) );
+
+ CALL_SUBTEST( test_complex<float>(256) );
+ CALL_SUBTEST( test_complex<double>(256) );
+ CALL_SUBTEST( test_complex<long double>(256) );
+
+ CALL_SUBTEST( test_complex<float>(3*8) );
+ CALL_SUBTEST( test_complex<double>(3*8) );
+ CALL_SUBTEST( test_complex<long double>(3*8) );
+
+ CALL_SUBTEST( test_complex<float>(5*32) );
+ CALL_SUBTEST( test_complex<double>(5*32) );
+ CALL_SUBTEST( test_complex<long double>(5*32) );
+
+ CALL_SUBTEST( test_complex<float>(2*3*4) );
+ CALL_SUBTEST( test_complex<double>(2*3*4) );
+ CALL_SUBTEST( test_complex<long double>(2*3*4) );
+
+ CALL_SUBTEST( test_complex<float>(2*3*4*5) );
+ CALL_SUBTEST( test_complex<double>(2*3*4*5) );
+ CALL_SUBTEST( test_complex<long double>(2*3*4*5) );
+
+ CALL_SUBTEST( test_complex<float>(2*3*4*5*7) );
+ CALL_SUBTEST( test_complex<double>(2*3*4*5*7) );
+ CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) );
+
+
+
+ CALL_SUBTEST( test_scalar<float>(32) );
+ CALL_SUBTEST( test_scalar<double>(32) );
+ CALL_SUBTEST( test_scalar<long double>(32) );
+
+ CALL_SUBTEST( test_scalar<float>(45) );
+ CALL_SUBTEST( test_scalar<double>(45) );
+ CALL_SUBTEST( test_scalar<long double>(45) );
+
+ CALL_SUBTEST( test_scalar<float>(50) );
+ CALL_SUBTEST( test_scalar<double>(50) );
+ CALL_SUBTEST( test_scalar<long double>(50) );
+
+ CALL_SUBTEST( test_scalar<float>(256) );
+ CALL_SUBTEST( test_scalar<double>(256) );
+ CALL_SUBTEST( test_scalar<long double>(256) );
+
+ CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) );
+ CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) );
+ CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) );
}