diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMkldnn.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionMkldnn.h | 116 |
1 file changed, 116 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMkldnn.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMkldnn.h new file mode 100644 index 000000000..a97f043c1 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMkldnn.h @@ -0,0 +1,116 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Eugene Zhulenev <ezhulenev@google.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MKLDNN_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MKLDNN_H + +#if defined(EIGEN_USE_MKLDNN) +// Support for MklDnn sgemm kernel in Tensor contractions: +// +// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using +// DataMapper (see TensorContractionInputMapper). +// 2. Invoke gemm kernel with packed blocks (replacement for default +// gebp_kernel). + +namespace Eigen { +namespace internal { + +template <typename Scalar, typename StorageIndex, typename DataMapper, + int StorageOrder> +struct mkldnn_gemm_pack; + +// mkl_gemm_pack for ColMajor storage order. 
+template <typename Scalar, typename StorageIndex, typename DataMapper> +struct mkldnn_gemm_pack<Scalar, StorageIndex, DataMapper, + /*StorageOrder*/ ColMajor> { + typedef typename internal::packet_traits<Scalar>::type Packet; + typedef typename DataMapper::LinearMapper LinearMapper; + + enum { PacketSize = internal::packet_traits<Scalar>::size }; + + EIGEN_DONT_INLINE + void operator()(Scalar *block, const DataMapper &data_mapper, + StorageIndex rows, StorageIndex cols) { + const StorageIndex unrolled_rows = + (rows / (4 * PacketSize)) * (4 * PacketSize); + const StorageIndex vectorized_rows = (rows / PacketSize) * PacketSize; + + for (StorageIndex col = 0; col < cols; ++col) { + LinearMapper lm = data_mapper.getLinearMapper(0, col); + + // Give compiler a strong possibility to unroll the loop. + for (StorageIndex i = 0; i < unrolled_rows; i += 4 * PacketSize) { + for (StorageIndex j = 0; j < 4; ++j) { + const Packet p = lm.template loadPacket<Packet>(i + j * PacketSize); + internal::pstoreu(block + j * PacketSize, p); + } + block += 4 * PacketSize; + } + + // Process remaining rows with packets. + for (StorageIndex i = unrolled_rows; i < vectorized_rows; + i += PacketSize) { + const Packet p = lm.template loadPacket<Packet>(i); + internal::pstoreu(block, p); + block += PacketSize; + } + + // Finalize with coefficients. + for (StorageIndex i = vectorized_rows; i < rows; ++i) { + *block = lm(i); + ++block; + } + } + } +}; + +template <typename Scalar, typename StorageIndex, typename OutputMapper, + bool ConjugateLhs = false, bool ConjugateRhs = false> +struct mkldnn_gemm_kernel; + +// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm. 
+template <typename StorageIndex, typename OutputMapper, bool ConjugateLhs, + bool ConjugateRhs> +struct mkldnn_gemm_kernel</*Scalar*/ float, StorageIndex, OutputMapper, + ConjugateLhs, ConjugateRhs> { + EIGEN_DONT_INLINE + void operator()(const OutputMapper &output, const float *blockA, + const float *blockB, const StorageIndex rows, + const StorageIndex depth, const StorageIndex cols, + float alpha) { + static const int max_index = (std::numeric_limits<int>::max)(); + + eigen_assert(max_index > rows); + eigen_assert(max_index > cols); + eigen_assert(max_index > depth); + eigen_assert(max_index > output.stride()); + + const int m = static_cast<int>(rows); + const int n = static_cast<int>(cols); + const int k = static_cast<int>(depth); + + const char transposeA = ConjugateLhs ? 'Y' : 'N'; + const char transposeB = ConjugateRhs ? 'Y' : 'N'; + + const int ldA = ConjugateLhs ? k : m; + const int ldB = ConjugateRhs ? n : k; + const int ldC = static_cast<int>(output.stride()); + + const float beta = 1.0; + + mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k, + &alpha, blockA, &ldA, blockB, &ldB, &beta, + const_cast<float*>(output.data()), &ldC); + eigen_assert(st == 0); + } +}; + +} // namespace internal +} // namespace Eigen +#endif // EIGEN_USE_MKLDNN +#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MKLDNN_H |