about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
author: Yangzihao Wang <yangzihao@google.com>, 2017-07-18 16:48:30 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-07-18 16:52:54 -0700
commit: 06acccabcb41513c76bbfffcd17817a7b136494b (patch)
tree: 82c2b379a5d8d4aa6cbbb653d3020ee2504bbd58 /tensorflow/stream_executor/stream.h
parent: 11dff5b05b3488520d3a415173d73ae91fded092 (diff)
Add autotuning code for matmul operator.
Currently it is turned off by default.

PiperOrigin-RevId: 162423171
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- tensorflow/stream_executor/stream.h 63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3c8b7ee894..e218873839 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -934,6 +934,31 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy);
+ Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
+ float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x,
+ int incx, float beta,
+ DeviceMemory<float> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
+ double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x,
+ int incx, double beta,
+ DeviceMemory<double> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
+ int incy, blas::ProfileResult *output_profile_result);
+
// See BlasSupport::DoBlasGer.
Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &x, int incx,
@@ -1249,6 +1274,44 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ float beta, DeviceMemory<Eigen::half> *c,
+ int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c,
+ int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+
// See BlasSupport::DoBlasGemmWithAlgorithm.
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,