aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
blob: 8dc5f3c93b6ba1a722ea7b23b4b5190ac0600cd6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/types.h"

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
#include "tensorflow/core/platform/dynamic_annotations.h"

using tensorflow::int32;
using tensorflow::int64;

namespace {
// BLAS GEMM API for 32-bit Matrix Multiplication.

// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0.
// Matrix lhs, rhs and out are all colum-major.
void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs,
               int64 m, int64 n, int64 k, int32 transpose_lhs,
               int32 transpose_rhs) {
  const float alpha = 1.0f, beta = 0.0f;
  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
  // respectively. For column-major matrices, the leading dimension is the
  // stride between consecutive columns (which equals the number of rows). If
  // the matrix is transposed, the leading dimension is the stride between
  // consecutive rows (which equals the number of columns).
  int lda = transpose_lhs ? k : m;
  int ldb = transpose_rhs ? n : k;
  int ldc = m;
  cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
              lda, rhs, ldb, beta, out, ldc);
}

// BLAS GEMM API for 64-bit Matrix Multiplication.

// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0.
// Matrix lhs, rhs and out are all colum-major.
void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
               double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
               int32 transpose_rhs) {
  const float alpha = 1.0f, beta = 0.0f;
  // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
  // respectively. For a column-major matrix, the leading dimension is the
  // stride between consecutive columns (which equals the number of rows). If
  // the matrix is transposed, the leading dimension is the stride between
  // consecutive rows (which equals the number of columns).
  int lda = transpose_lhs ? k : m;
  int ldb = transpose_rhs ? n : k;
  int ldc = m;
  cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
              transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
              lda, rhs, ldb, beta, out, ldc);
}

}  // namespace

// Multi-threaded F32 matmul entry point invoked by XLA-generated CPU code.
//
// The function temporarily sets MKL's thread-local thread count to the width
// of the intra-op thread pool carried in `run_options_ptr`. It then runs the
// GEMM and restores whatever thread-local setting was in effect beforehand.
// NOTE(review): assumes run_options_ptr points to a valid
// xla::ExecutableRunOptions whose intra_op_thread_pool() is non-null. Neither
// condition is checked here.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
    int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
  const auto* options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
  const auto* pool = options->intra_op_thread_pool();
  // MKL's GEMM parallelizes with its own OpenMP threads, so it is told to
  // match the size of XLA's intra-op pool.
  const int saved_threads = mkl_set_num_threads_local(pool->numThreads());
  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
  mkl_set_num_threads_local(saved_threads);  // Undo the override.
}

// Multi-threaded F64 matmul entry point invoked by XLA-generated CPU code.
//
// This mirrors the F32 entry point. It sizes MKL's thread-local thread count
// from the intra-op thread pool in `run_options_ptr`, runs the double-precision
// GEMM, then restores the previous thread-local setting.
// NOTE(review): assumes run_options_ptr points to a valid
// xla::ExecutableRunOptions whose intra_op_thread_pool() is non-null. Neither
// condition is checked here.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64(
    const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
    int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
  const auto* options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
  const auto* pool = options->intra_op_thread_pool();
  // MKL's GEMM parallelizes with its own OpenMP threads, so it is told to
  // match the size of XLA's intra-op pool.
  const int saved_threads = mkl_set_num_threads_local(pool->numThreads());
  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
  mkl_set_num_threads_local(saved_threads);  // Undo the override.
}

// Single-threaded F32 matmul entry point invoked by XLA-generated CPU code.
//
// The function pins MKL's thread-local thread count to one for the duration of
// the GEMM, then restores the previous thread-local setting.
// `run_options_ptr` is not consulted.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
                                             float* out, float* lhs, float* rhs,
                                             int64 m, int64 n, int64 k,
                                             int32 transpose_lhs,
                                             int32 transpose_rhs) {
  constexpr int kSingleThread = 1;
  const int saved_threads = mkl_set_num_threads_local(kSingleThread);
  MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
  mkl_set_num_threads_local(saved_threads);  // Undo the override.
}

// Single-threaded F64 matmul entry point invoked by XLA-generated CPU code.
//
// The function pins MKL's thread-local thread count to one for the duration of
// the double-precision GEMM, then restores the previous thread-local setting.
// `run_options_ptr` is not consulted.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
                                             double* out, double* lhs,
                                             double* rhs, int64 m, int64 n,
                                             int64 k, int32 transpose_lhs,
                                             int32 transpose_rhs) {
  constexpr int kSingleThread = 1;
  const int saved_threads = mkl_set_num_threads_local(kSingleThread);
  MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
  mkl_set_num_threads_local(saved_threads);  // Undo the override.
}
#endif  // INTEL_MKL