aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-26 14:29:26 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-26 14:29:26 -0700
commit1c8312c811344beaa06f7ae9258f66c38337c607 (patch)
tree4436a04ce900a997aa6adf9d42d0c6dab0a07fac /unsupported/Eigen/CXX11/src/Tensor
parent1f4c98abb1634bdbdd6583b55ba36dcc09ef5773 (diff)
Started to add support for tensors of rank 0
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/Tensor.h31
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorIO.h5
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h12
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h12
4 files changed, 57 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
index 3ac465d24..0df1345c2 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h
@@ -140,6 +140,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff() const
+ {
+ EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return m_storage.data()[0];
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
{
eigen_internal_assert(index >= 0 && index < size());
@@ -174,6 +180,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
}
#endif
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef()
+ {
+ EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return m_storage.data()[0];
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
{
eigen_internal_assert(index >= 0 && index < size());
@@ -234,6 +246,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
return coeff(index);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()() const
+ {
+ EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return coeff();
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator[](Index index) const
{
// The bracket operator is only for vectors, use the parenthesis operator instead.
@@ -295,6 +313,12 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
return coeffRef(index);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()()
+ {
+ EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return coeffRef();
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator[](Index index)
{
// The bracket operator is only for vectors, use the parenthesis operator instead
@@ -433,6 +457,13 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
resize(dims);
}
+ EIGEN_DEVICE_FUNC
+ void resize()
+ {
+ EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // Nothing to do: rank 0 tensors have fixed size
+ }
+
/** Custom Dimension */
#ifdef EIGEN_HAS_SFINAE
template<typename CustomDimension,
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h
index 3b6f2c730..38a833f82 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h
@@ -33,7 +33,10 @@ std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccesso
const Index total_size = internal::array_prod(tensor.dimensions());
// Print the tensor as a 1d vector or a 2d matrix.
- if (internal::array_size<Dimensions>::value == 1) {
+ static const int rank = internal::array_size<Dimensions>::value;
+ if (rank == 0) {
+ os << tensor.coeff(0);
+ } else if (rank == 1) {
Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size);
os << array;
} else {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h b/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h
index 4303e3536..ad2a1e6ac 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h
@@ -55,6 +55,18 @@ struct Initializer<Derived, 1> {
}
};
+template <typename Derived>
+struct Initializer<Derived, 0> {
+ typedef typename traits<Derived>::Scalar InitList;
+
+ static void run(TensorEvaluator<Derived, DefaultDevice>& tensor,
+ Eigen::array<typename traits<Derived>::Index, traits<Derived>::NumDimensions>*/* indices*/,
+ const InitList& v) {
+ tensor.coeffRef(0) = v;
+ }
+};
+
+
template <typename Derived, int N>
void initialize_tensor(TensorEvaluator<Derived, DefaultDevice>& tensor,
const typename Initializer<Derived, traits<Derived>::NumDimensions>::InitList& vals) {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h
index 9e4cf039d..ee6f14b8f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h
@@ -71,7 +71,11 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_>
typedef DSizes<IndexType, NumIndices_> Dimensions;
typedef TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_> Self;
- EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {}
+ EIGEN_DEVICE_FUNC TensorStorage() : m_data(0), m_dimensions() {
+ if (NumIndices_ == 0) {
+ m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1);
+ }
+ }
EIGEN_DEVICE_FUNC TensorStorage(internal::constructor_without_unaligned_array_assert)
: m_data(0), m_dimensions(internal::template repeat<NumIndices_, Index>(0)) {}
EIGEN_DEVICE_FUNC TensorStorage(Index size, const array<Index, NumIndices_>& dimensions)
@@ -101,13 +105,17 @@ class TensorStorage<T, DSizes<IndexType, NumIndices_>, Options_>
EIGEN_DEVICE_FUNC void resize(Index size, const array<Index, NumIndices_>& nbDimensions)
{
+ eigen_assert(size >= 1);
const Index currentSz = internal::array_prod(m_dimensions);
if(size != currentSz)
{
internal::conditional_aligned_delete_auto<T,(Options_&DontAlign)==0>(m_data, currentSz);
if (size)
m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(size);
- else
+ else if (NumIndices_ == 0) {
+ m_data = internal::conditional_aligned_new_auto<T,(Options_&DontAlign)==0>(1);
+ }
+ else
m_data = 0;
EIGEN_INTERNAL_DENSE_STORAGE_CTOR_PLUGIN
}