diff options
-rw-r--r-- | src/Core/Identity.h | 15 | ||||
-rw-r--r-- | src/Core/Object.h | 2 | ||||
-rw-r--r-- | test/basicstuff.cpp | 2 |
3 files changed, 10 insertions, 9 deletions
diff --git a/src/Core/Identity.h b/src/Core/Identity.h index 78b380b04..4b437a844 100644 --- a/src/Core/Identity.h +++ b/src/Core/Identity.h @@ -36,30 +36,31 @@ template<typename MatrixType> class Identity static const int RowsAtCompileTime = MatrixType::RowsAtCompileTime, ColsAtCompileTime = MatrixType::ColsAtCompileTime; - Identity(int rows, int cols) : m_rows(rows), m_cols(cols) + Identity(int rows) : m_rows(rows) { - assert(rows > 0 && cols > 0); + assert(rows > 0); + assert(RowsAtCompileTime == ColsAtCompileTime); } private: Identity& _ref() { return *this; } const Identity& _constRef() const { return *this; } int _rows() const { return m_rows; } - int _cols() const { return m_cols; } + int _cols() const { return m_rows; } Scalar _read(int row, int col) const { - return static_cast<Scalar>(row == col); + return row == col ? static_cast<Scalar>(1) : static_cast<Scalar>(0); } protected: - int m_rows, m_cols; + int m_rows; }; template<typename Scalar, typename Derived> -Identity<Derived> Object<Scalar, Derived>::identity(int rows, int cols) +Identity<Derived> Object<Scalar, Derived>::identity(int rows) { - return Identity<Derived>(rows, cols); + return Identity<Derived>(rows); } #endif // EI_IDENTITY_H diff --git a/src/Core/Object.h b/src/Core/Object.h index 5b2e34f53..5792385e6 100644 --- a/src/Core/Object.h +++ b/src/Core/Object.h @@ -116,7 +116,7 @@ template<typename Scalar, typename Derived> class Object static Zero<Derived> zero(int rows = RowsAtCompileTime, int cols = ColsAtCompileTime); static Identity<Derived> - identity(int rows = RowsAtCompileTime, int cols = ColsAtCompileTime); + identity(int rows = RowsAtCompileTime); template<typename OtherDerived> bool isApprox( diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp index c21a58979..3e734a498 100644 --- a/test/basicstuff.cpp +++ b/test/basicstuff.cpp @@ -37,7 +37,7 @@ template<typename MatrixType> void basicStuff(const MatrixType& m) m3, mzero = MatrixType::zero(rows, cols), identity = Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> - ::identity(rows, rows), + ::identity(rows), square = Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> ::random(rows, rows); VectorType v1 = VectorType::random(rows), |