diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-13 14:24:37 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-13 14:24:37 -0800 |
commit | 9f013a9d86ad5cf82939bfeab2223652a821c448 (patch) | |
tree | 47e8f1d08fe1e1cf79cb793e43937592782e8c9d | |
parent | 79b69b7444cfae2f7631e873e822cdca6f4e355f (diff) |
Properly record the rank of reduced tensors in the tensor traits.
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h | 9 |
2 files changed, 7 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h index c783aab97..781a37e34 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h @@ -134,7 +134,7 @@ struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<Xp typedef Index Scalar; typedef typename XprType::Nested Nested; typedef typename remove_reference<Nested>::type _Nested; - static const int NumDimensions = XprTraits::NumDimensions; + static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value; static const int Layout = XprTraits::Layout; }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h index 8028e71c0..2dc8815b8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h @@ -24,11 +24,14 @@ template<typename Op, typename Dims, typename XprType> struct traits<TensorReductionOp<Op, Dims, XprType> > : traits<XprType> { - typedef typename traits<XprType>::Scalar Scalar; + typedef traits<XprType> XprTraits; + typedef typename XprTraits::Scalar Scalar; typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename traits<XprType>::StorageKind StorageKind; - typedef typename traits<XprType>::Index Index; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; typedef typename XprType::Nested Nested; + static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value; + static const int Layout = XprTraits::Layout; }; template<typename Op, typename Dims, typename XprType> |