aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h14
1 files changed, 4 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index c459fc649..f0f61fade 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -152,13 +152,7 @@ struct TensorContractionParams {
// 1. Elementwise Relu transformation following Conv2D.
// 2. AddBias to the Conv2D output channels dimension.
//
-// See expected implementation in NoOpOutputKernel.
-struct OutputKernel {
- template <typename Index, typename Scalar>
- using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
-};
-
-// Output kernel that does absolutely nothing.
+// The NoOpOutputKernel implements an output kernel that does absolutely nothing.
struct NoOpOutputKernel {
/**
* Tensor contraction evaluator calls this kernel after finishing each block
@@ -177,7 +171,7 @@ struct NoOpOutputKernel {
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
- const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
+ const internal::blas_data_mapper<Scalar, Index, ColMajor>& /*output_mapper*/,
const TensorContractionParams& /*params*/, Index /*i*/,
Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
};
@@ -354,7 +348,7 @@ struct TensorContractionEvaluatorBase
// dimensions and right non-contracting dimensions.
m_lhs_inner_dim_contiguous = true;
int dim_idx = 0;
- unsigned int nocontract_idx = 0;
+ Index nocontract_idx = 0;
for (int i = 0; i < LDims; i++) {
// find if we are contracting on index i of left tensor
@@ -667,7 +661,7 @@ struct TensorContractionEvaluatorBase
// call gebp (matrix kernel)
// The parameters here are copied from Eigen's GEMM implementation
- const auto output_mapper = output.getSubMapper(i2, j2);
+ const OutputMapper output_mapper = output.getSubMapper(i2, j2);
gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
Scalar(1), -1, -1, 0, 0);