aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <eugene.zhulenev@gmail.com>2019-12-18 20:07:00 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-12-18 20:07:00 +0000
commitae07801dd8d295657f28b006e1e4999edf835052 (patch)
tree08a91a4368c15d365127344f920bd10f8e437db2 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent72166d0e6eaf12a99f449e26f402f926bef2bb50 (diff)
Tensor block evaluation cost model
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h23
1 files changed, 16 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 146cc325e..d4532b72c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -521,7 +521,9 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
internal::TensorBlockResourceRequirements getResourceRequirements() const {
- return m_argImpl.getResourceRequirements();
+ static const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
+ return m_argImpl.getResourceRequirements().addCostPerCoeff(
+ {0, 0, functor_cost / PacketSize});
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
@@ -654,9 +656,11 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
internal::TensorBlockResourceRequirements getResourceRequirements() const {
+ static const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
return internal::TensorBlockResourceRequirements::merge(
- m_leftImpl.getResourceRequirements(),
- m_rightImpl.getResourceRequirements());
+ m_leftImpl.getResourceRequirements(),
+ m_rightImpl.getResourceRequirements())
+ .addCostPerCoeff({0, 0, functor_cost / PacketSize});
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
@@ -934,11 +938,16 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
internal::TensorBlockResourceRequirements getResourceRequirements() const {
+ auto then_req = m_thenImpl.getResourceRequirements();
+ auto else_req = m_elseImpl.getResourceRequirements();
+
+ auto merged_req =
+ internal::TensorBlockResourceRequirements::merge(then_req, else_req);
+ merged_req.cost_per_coeff =
+ then_req.cost_per_coeff.cwiseMax(else_req.cost_per_coeff);
+
return internal::TensorBlockResourceRequirements::merge(
- m_condImpl.getResourceRequirements(),
- internal::TensorBlockResourceRequirements::merge(
- m_thenImpl.getResourceRequirements(),
- m_elseImpl.getResourceRequirements()));
+ m_condImpl.getResourceRequirements(), merged_req);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock