aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-22 14:37:26 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-22 14:37:26 -0800
commit4beb447e27baaa19081e835bd6aba76e9b02cc67 (patch)
tree5c37885e2748623e69dd663db4194409ff056e2f /unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
parent5358c3858963e03581640e58ea1f3adbdd03b831 (diff)
Created a mechanism to enable contraction mappers to determine the best blocking strategy.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h58
1 files changed, 58 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
new file mode 100644
index 000000000..78ed5038f
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
@@ -0,0 +1,58 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
+
+
+namespace Eigen {
+namespace internal {
+
+// Values accepted by the ShardingType template parameter of
+// TensorContractionBlocking; ShardByCol is that parameter's default.
+enum { ShardByRow = 0, ShardByCol = 1 };
+
+
+// Default blocking strategy for tensor contractions.
+//
+// The constructor starts from the full problem extents (k, m, n) and derives
+// the block sizes reported by kc(), mc() and nc():
+//  - ShardingType == ShardByCol: the three sizes and num_threads are handed
+//    to computeProductBlockingSizes() (presumably updated in place; they are
+//    passed as lvalues).
+//  - any other ShardingType (i.e. ShardByRow): kc and nc keep the full
+//    extents, and mc becomes the per-thread share of m rounded up to a
+//    multiple of 16. Nothing is changed when any of k, m, n is zero.
+//
+// LhsMapper and RhsMapper must each expose a nested Scalar type.
+template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
+class TensorContractionBlocking {
+ public:
+
+  typedef typename LhsMapper::Scalar LhsScalar;
+  typedef typename RhsMapper::Scalar RhsScalar;
+
+  // k, m, n: extents of the contraction; num_threads: number of threads the
+  // work is meant to be split across (defaults to 1).
+  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
+      kc_(k), mc_(m), nc_(n)
+  {
+    if (ShardingType == ShardByCol) {
+      computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
+    }
+    else if (kc_ && mc_ && nc_) {
+      // A thread count below one would divide by zero (or yield a negative
+      // share) in the division below, so treat it as a single thread.
+      const Index threads = num_threads < 1 ? Index(1) : num_threads;
+
+      // Integer division truncates to 0 when m < threads, which used to
+      // produce mc_ == 0 for a non-empty problem. Keep at least one row per
+      // thread so the rounded block size is never smaller than 16.
+      const Index share = m / threads;
+      const Index rows_per_thread = share < 1 ? Index(1) : share;
+
+      // Round the per-thread share up to the next multiple of 16.
+      mc_ = ((rows_per_thread + 15) / 16) * 16;
+    }
+  }
+
+  // Block sizes selected by the constructor.
+  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
+  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
+  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
+
+ private:
+  Index kc_;
+  Index mc_;
+  Index nc_;
+};
+
+
+} // end namespace internal
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H