aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2016-04-14 13:57:35 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2016-04-14 13:57:35 -0700
commit235e83aba608cf3d94b033bfbf551f8c136a3fab (patch)
tree7b011fee8fe18b605320c69e75995cf8521fbdf4 /unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
parent3551dea887ce60756c28796e83bb7c080f2b2782 (diff)
Eigen cost model part 1. This implements a basic recursive framework to estimate the cost of evaluating tensor expressions.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h13
1 files changed, 13 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index b7c13f67f..ccaa757d1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -594,6 +594,8 @@ template <> class UniformRandomGenerator<std::complex<double> > {
template <typename Scalar>
struct functor_traits<UniformRandomGenerator<Scalar> > {
enum {
+ // Rough estimate.
+ Cost = 100 * NumTraits<Scalar>::MulCost,
PacketAccess = UniformRandomGenerator<Scalar>::PacketAccess
};
};
@@ -774,6 +776,8 @@ template <typename T> class NormalRandomGenerator {
template <typename Scalar>
struct functor_traits<NormalRandomGenerator<Scalar> > {
enum {
+ // Rough estimate.
+ Cost = 100 * NumTraits<Scalar>::MulCost,
PacketAccess = NormalRandomGenerator<Scalar>::PacketAccess
};
};
@@ -807,6 +811,15 @@ class GaussianGenerator {
array<T, NumDims> m_two_sigmas;
};
+template <typename T, typename Index, size_t NumDims>
+struct functor_traits<GaussianGenerator<T, Index, NumDims> > {
+ enum {
+ Cost = NumDims * (2 * NumTraits<T>::AddCost + NumTraits<T>::MulCost +
+ functor_traits<scalar_quotient_op<T, T> >::Cost) +
+ functor_traits<scalar_exp_op<T> >::Cost,
+ PacketAccess = GaussianGenerator<T, Index, NumDims>::PacketAccess
+ };
+};
} // end namespace internal
} // end namespace Eigen