diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-07-21 09:22:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-21 09:38:31 -0700 |
commit | 3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch) | |
tree | c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/stream_executor/stream.cc | |
parent | 4729180d24af3126d736a7045c43fcbf031b5bef (diff) |
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 5996195173..c9b36ba7ab 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -3458,6 +3458,184 @@ struct ThenBlasWithProfileImpl { }; } // anonymous namespace +Stream &Stream::ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, + int incx, float beta, DeviceMemory<float> *y, int incy, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), + PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), + PARAM(incy)); + + ThenBlasWithProfileImpl< + blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int, + const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, + alpha, a, lda, x, incx, beta, y, incy, output_profile_result); +} + +Stream &Stream::ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, + int incx, double beta, DeviceMemory<double> *y, int incy, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), + PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), + PARAM(incy)); + + ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double, + const DeviceMemory<double> &, int, + const DeviceMemory<double> &, int, double, + DeviceMemory<double> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, + alpha, a, lda, x, incx, beta, y, incy, output_profile_result); +} + +Stream &Stream::ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), + PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), + PARAM(incy)); + + ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>, + const DeviceMemory<std::complex<float>> &, int, + const DeviceMemory<std::complex<float>> &, int, + std::complex<float>, + DeviceMemory<std::complex<float>> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, + alpha, a, lda, x, incx, beta, y, incy, output_profile_result); +} + +Stream &Stream::ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), + PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), + PARAM(incy)); + + ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>, + const DeviceMemory<std::complex<double>> &, int, + const DeviceMemory<std::complex<double>> &, int, + std::complex<double>, + DeviceMemory<std::complex<double>> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, + alpha, a, lda, x, incx, beta, y, incy, output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, float beta, + DeviceMemory<Eigen::half> *c, int ldc, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc)); + + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, float, const DeviceMemory<Eigen::half> &, int, + const DeviceMemory<Eigen::half> &, int, float, + DeviceMemory<Eigen::half> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc)); + + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, float, const DeviceMemory<float> &, int, + const DeviceMemory<float> &, int, float, + DeviceMemory<float> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc)); + + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, double, const DeviceMemory<double> &, int, + const DeviceMemory<double> &, int, double, + DeviceMemory<double> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc)); + + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, int, + const DeviceMemory<std::complex<float>> &, int, std::complex<float>, + DeviceMemory<std::complex<float>> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc)); + + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, int, + const DeviceMemory<std::complex<double>> &, int, std::complex<double>, + DeviceMemory<std::complex<double>> *, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + output_profile_result); +} + Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, |