path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
author    Eugene Zhulenev <ezhulenev@google.com>    2019-05-16 16:15:45 -0700
committer Eugene Zhulenev <ezhulenev@google.com>    2019-05-16 16:15:45 -0700
commit 96a276803c68274396af1e3411bc6d3f6921f8c7 (patch)
tree   0534fe65bb2658cbd646e583adbc54de158217a1 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
parent c8d8d5c0fcfe31eb43005245e36627e104ad2e5f (diff)
Always evaluate Tensor expressions with broadcasting via tiled evaluation code path
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r--    unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h    57
1 file changed, 52 insertions(+), 5 deletions(-)
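For context, here is a minimal sketch (not part of the patch, and assuming the unsupported Eigen Tensor module is on the include path) of the kind of expression this commit targets: a broadcast feeding a coefficient-wise op whose result easily fits into the L1 cache. Before this change such an assignment was rerouted to the plain vectorized executor; with the patch, the presence of a TensorBroadcastingOp keeps it on the tiled code path. The names lhs, bias, and bcast are illustrative.

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A 4x8 result easily fits into L1, so the old small-tensor fallback
  // used to apply to this assignment.
  Eigen::Tensor<float, 2> lhs(4, 8);
  lhs.setRandom();

  Eigen::Tensor<float, 2> bias(1, 8);
  bias.setRandom();

  // Replicate the single row of `bias` four times: 1x8 -> 4x8.
  Eigen::array<Eigen::Index, 2> bcast = {4, 1};

  // The RHS contains a TensorBroadcastingOp, so the assignment is now
  // evaluated via the tiled (block-based) executor.
  Eigen::Tensor<float, 2> result = lhs + bias.broadcast(bcast);
  (void)result;
  return 0;
}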
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index e2ff11129..d57203ad9 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -29,6 +29,50 @@ namespace Eigen {
namespace internal {
/**
+ * Evaluating TensorBroadcastingOp via the coefficient or packet path is
+ * extremely expensive. If an expression has at least one broadcast op in it,
+ * and it supports block-based evaluation, we always prefer it, even for small
+ * tensors. For all other tileable ops, the block evaluation overhead for small
+ * tensors (which fit into L1) is too large, and we fall back to vectorized
+ * evaluation.
+ */
+
+// TODO(ezhulenev): Add specializations for all other types of Tensor ops.
+
+template<typename Expression>
+struct ExpressionHasTensorBroadcastingOp {
+ enum { value = false };
+};
+
+template<typename LhsXprType, typename RhsXprType>
+struct ExpressionHasTensorBroadcastingOp<
+ const TensorAssignOp<LhsXprType, RhsXprType> > {
+ enum { value = ExpressionHasTensorBroadcastingOp<RhsXprType>::value };
+};
+
+template<typename UnaryOp, typename XprType>
+struct ExpressionHasTensorBroadcastingOp<
+ const TensorCwiseUnaryOp<UnaryOp, XprType> > {
+ enum { value = ExpressionHasTensorBroadcastingOp<XprType>::value };
+};
+
+template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
+struct ExpressionHasTensorBroadcastingOp<
+ const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > {
+ enum {
+ value = ExpressionHasTensorBroadcastingOp<LhsXprType>::value ||
+ ExpressionHasTensorBroadcastingOp<RhsXprType>::value
+ };
+};
+
+template<typename Broadcast, typename XprType>
+struct ExpressionHasTensorBroadcastingOp<
+ const TensorBroadcastingOp<Broadcast, XprType> > {
+ enum { value = true };
+};
+
+// -------------------------------------------------------------------------- //
+
+/**
* Default strategy: the expression is evaluated sequentially with a single cpu
* thread, without vectorization and block evaluation.
*/
@@ -121,11 +165,12 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable,
Index total_size = array_prod(evaluator.dimensions());
Index cache_size = device.firstLevelCacheSize() / sizeof(Scalar);
- if (total_size < cache_size) {
+ if (total_size < cache_size
+ && !ExpressionHasTensorBroadcastingOp<Expression>::value) {
// TODO(andydavis) Reduce block management overhead for small tensors.
- // TODO(wuke) Do not do this when evaluating TensorBroadcastingOp.
internal::TensorExecutor<Expression, DefaultDevice, Vectorizable,
- /*Tileable*/ false>::run(expr, device);
+ /*Tileable*/ false>::run(expr, device);
+ evaluator.cleanup();
return;
}
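To illustrate the guard added above, here is a hypothetical compile-time check (not part of the commit, and only valid against an Eigen tree with this patch applied): ExpressionHasTensorBroadcastingOp recurses through TensorAssignOp, TensorCwiseUnaryOp and TensorCwiseBinaryOp and reports true as soon as a TensorBroadcastingOp is found, which is what disables the small-tensor fallback.

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 1> t(16), b(8);
  Eigen::array<Eigen::Index, 1> bcast = {2};  // 8 -> 16 when broadcast

  // decltype is an unevaluated context: only the expression type is inspected.
  // The sum is a TensorCwiseBinaryOp whose RHS is a TensorBroadcastingOp,
  // so the trait recurses into it and yields true.
  using WithBroadcast = decltype(t + b.broadcast(bcast));
  static_assert(
      Eigen::internal::ExpressionHasTensorBroadcastingOp<WithBroadcast>::value,
      "broadcast detected");

  // A plain cwise sum contains no broadcast, so the primary template applies
  // and the small-tensor fallback remains in effect.
  using NoBroadcast = decltype(t + t);
  static_assert(
      !Eigen::internal::ExpressionHasTensorBroadcastingOp<NoBroadcast>::value,
      "no broadcast");
  return 0;
}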
@@ -260,10 +305,12 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ tr
Evaluator evaluator(expr, device);
Index total_size = array_prod(evaluator.dimensions());
Index cache_size = device.firstLevelCacheSize() / sizeof(Scalar);
- if (total_size < cache_size) {
+
+ if (total_size < cache_size
+ && !ExpressionHasTensorBroadcastingOp<Expression>::value) {
// TODO(andydavis) Reduce block management overhead for small tensors.
internal::TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
- false>::run(expr, device);
+ /*Tileable*/ false>::run(expr, device);
evaluator.cleanup();
return;
}
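Finally, a usage-level sketch (again hypothetical, not taken from the commit) of the second hunk's effect: the ThreadPoolDevice executor applies the same guard, so a small broadcasting assignment submitted through .device(...) also stays on the tiled path. The pool size and tensor shapes are arbitrary.

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  // 8x16 floats is far below a typical L1 cache size, so before this patch
  // the assignment below would have been rerouted to the non-tiled executor.
  Eigen::Tensor<float, 2> in(2, 16), out(8, 16);
  in.setRandom();

  Eigen::array<Eigen::Index, 2> bcast = {4, 1};  // 2x16 -> 8x16
  // .device(...) dispatches through TensorExecutor<..., ThreadPoolDevice, ...>;
  // the TensorBroadcastingOp in the RHS now keeps it on the tiled path.
  out.device(device) = in.broadcast(bcast);
  return 0;
}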