about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/mkl_matmul_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_matmul_op.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_matmul_op.cc | 28
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 47598f443f..dfa6cecc9b 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
// For detailed info about parameters, look at FP32 function description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const std::complex<float>* a, const int lda,
- const std::complex<float>* b, const int ldb,
- std::complex<float>* c, int const ldc) {
+ const int k, const complex64* a, const int lda,
+ const complex64* b, const int ldb, complex64* c,
+ int const ldc) {
const MKL_Complex8 alpha = {1.0f, 0.0f};
const MKL_Complex8 beta = {0.0f, 0.0f};
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
- transb ? CblasTrans : CblasNoTrans, m, n, k,
- static_cast<const void*>(&alpha), static_cast<const void*>(a),
- lda, static_cast<const void*>(b), ldb,
- static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+ transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
+ reinterpret_cast<const MKL_Complex8*>(a), lda,
+ reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
+ reinterpret_cast<MKL_Complex8*>(c), ldc);
}
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
// tensors. For detailed info about parameters, look at FP32 function
// description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const std::complex<double>* a, const int lda,
- const std::complex<double>* b, const int ldb,
- std::complex<double>* c, const int ldc) {
+ const int k, const complex128* a, const int lda,
+ const complex128* b, const int ldb, complex128* c,
+ const int ldc) {
const MKL_Complex16 alpha = {1.0, 0.0};
const MKL_Complex16 beta = {0.0, 0.0};
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
- transb ? CblasTrans : CblasNoTrans, m, n, k,
- static_cast<const void*>(&alpha), static_cast<const void*>(a),
- lda, static_cast<const void*>(b), ldb,
- static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+ transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
+ reinterpret_cast<const MKL_Complex16*>(a), lda,
+ reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
+ reinterpret_cast<MKL_Complex16*>(c), ldc);
}
};