diff options
-rw-r--r-- | unsupported/Eigen/src/NumericalDiff/NumericalDiff.h | 51 | ||||
-rw-r--r-- | unsupported/test/NumericalDiff.cpp | 19 |
2 files changed, 59 insertions, 11 deletions
diff --git a/unsupported/Eigen/src/NumericalDiff/NumericalDiff.h b/unsupported/Eigen/src/NumericalDiff/NumericalDiff.h index 223cf9e0f..276b315f8 100644 --- a/unsupported/Eigen/src/NumericalDiff/NumericalDiff.h +++ b/unsupported/Eigen/src/NumericalDiff/NumericalDiff.h @@ -28,7 +28,13 @@ namespace Eigen { -template<typename Functor> class NumericalDiff : public Functor +enum NumericalDiffMode { + Forward, + Central +}; + + +template<typename Functor, NumericalDiffMode mode=Forward> class NumericalDiff : public Functor { public: typedef typename Functor::Scalar Scalar; @@ -62,14 +68,23 @@ public: int nfev=0; const int n = _x.size(); const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() ))); - ValueType val, fx; + ValueType val1, val2; InputType x = _x; // TODO : we should do this only if the size is not already known - val.resize(Functor::values()); - fx.resize(Functor::values()); + val1.resize(Functor::values()); + val2.resize(Functor::values()); - // compute f(x) - Functor::operator()(x, fx); + switch(mode) { + case Forward: + // compute f(x) + Functor::operator()(x, val1); nfev++; + break; + case Central: + // do nothing + break; + default: + assert(false); + }; /* Function Body */ @@ -78,11 +93,25 @@ public: if (h == 0.) { h = eps; } - x[j] += h; - Functor::operator()(x, val); - nfev++; - x[j] = _x[j]; - jac.col(j) = (val-fx)/h; + switch(mode) { + case Forward: + x[j] += h; + Functor::operator()(x, val2); + nfev++; + x[j] = _x[j]; + jac.col(j) = (val2-val1)/h; + break; + case Central: + x[j] += h; + Functor::operator()(x, val2); nfev++; + x[j] -= 2*h; + Functor::operator()(x, val1); nfev++; + x[j] = _x[j]; + jac.col(j) = (val2-val1)/(2*h); + break; + default: + assert(false); + }; } return nfev; } diff --git a/unsupported/test/NumericalDiff.cpp b/unsupported/test/NumericalDiff.cpp index 1bc9e614a..ba9f4331d 100644 --- a/unsupported/test/NumericalDiff.cpp +++ b/unsupported/test/NumericalDiff.cpp @@ -88,8 +88,27 @@ void test_forward() VERIFY_IS_APPROX(jac, actual_jac); } +void test_central() +{ + VectorXd x(3); + MatrixXd jac(15,3); + MatrixXd actual_jac(15,3); + my_functor functor; + + x << 0.082, 1.13, 2.35; + + // real one + functor.df(x, actual_jac); + + // using NumericalDiff + NumericalDiff<my_functor,Central> numDiff(functor); + numDiff.df(x, jac); + + VERIFY_IS_APPROX(jac, actual_jac); +} void test_NumericalDiff() { CALL_SUBTEST(test_forward()); + CALL_SUBTEST(test_central()); } |