diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_matmul_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_matmul_op.cc | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 47598f443f..dfa6cecc9b 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel { // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors. // For detailed info about parameters, look at FP32 function description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex<float>* a, const int lda, - const std::complex<float>* b, const int ldb, - std::complex<float>* c, int const ldc) { + const int k, const complex64* a, const int lda, + const complex64* b, const int ldb, complex64* c, + int const ldc) { const MKL_Complex8 alpha = {1.0f, 0.0f}; const MKL_Complex8 beta = {0.0f, 0.0f}; cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast<const void*>(&alpha), static_cast<const void*>(a), - lda, static_cast<const void*>(b), ldb, - static_cast<const void*>(&beta), static_cast<void*>(c), ldc); + transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, + reinterpret_cast<const MKL_Complex8*>(a), lda, + reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta, + reinterpret_cast<MKL_Complex8*>(c), ldc); } // Matrix-Matrix Multiplication with Complex128 (std::complex<double>) // tensors. For detailed info about parameters, look at FP32 function // description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex<double>* a, const int lda, - const std::complex<double>* b, const int ldb, - std::complex<double>* c, const int ldc) { + const int k, const complex128* a, const int lda, + const complex128* b, const int ldb, complex128* c, + const int ldc) { const MKL_Complex16 alpha = {1.0, 0.0}; const MKL_Complex16 beta = {0.0, 0.0}; cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast<const void*>(&alpha), static_cast<const void*>(a), - lda, static_cast<const void*>(b), ldb, - static_cast<const void*>(&beta), static_cast<void*>(c), ldc); + transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, + reinterpret_cast<const MKL_Complex16*>(a), lda, + reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta, + reinterpret_cast<MKL_Complex16*>(c), ldc); } }; |