diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-07-21 09:22:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-21 09:38:31 -0700 |
commit | 3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch) | |
tree | c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/stream_executor/blas.h | |
parent | 4729180d24af3126d736a7045c43fcbf031b5bef (diff) |
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 138 |
1 files changed, 134 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index cfff3649c8..eb1b19c5d9 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/lib/array_slice.h" -#include "tensorflow/stream_executor/platform/port.h" namespace Eigen { struct half; @@ -108,6 +107,10 @@ string ComputationTypeString(ComputationType ty); // Opaque identifier for an "algorithm" used by a blas routine. This functions // as a hint to the blas library. typedef int64 AlgorithmType; +constexpr AlgorithmType kDefaultAlgorithm = -1; +constexpr AlgorithmType kDefaultBlasGemm = -2; +constexpr AlgorithmType kDefaultBlasGemv = -3; +constexpr AlgorithmType kNoAlgorithm = -4; // blas uses -1 to represent the default algorithm. This happens to match up // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast @@ -134,10 +137,28 @@ class ProfileResult { private: bool is_valid_ = false; - AlgorithmType algorithm_ = 0; + AlgorithmType algorithm_ = kDefaultAlgorithm; float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); }; +class AlgorithmConfig { + public: + AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} + explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} + AlgorithmType algorithm() const { return algorithm_; } + void set_algorithm(AlgorithmType val) { algorithm_ = val; } + bool operator==(const AlgorithmConfig &other) const { + return this->algorithm_ == other.algorithm_; + } + bool operator!=(const AlgorithmConfig &other) const { + return !(*this == other); + } + string ToString() const; + + private: + AlgorithmType algorithm_; +}; + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -453,6 +474,29 @@ class BlasSupport { std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy) = 0; + virtual bool DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, + int incx, float beta, DeviceMemory<float> *y, int incy, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, + int incx, double beta, DeviceMemory<double> *y, int incy, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, + std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, + int lda, const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, + std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, + int lda, const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, DeviceMemory<std::complex<double>> *y, + int incy, ProfileResult *output_profile_result) = 0; + // Performs a rank-1 update of a general matrix. // // a <- alpha * x * y' + a, @@ -935,8 +979,39 @@ class BlasSupport { std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc) = 0; - // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that - // any or all of these algorithms may still be + virtual bool DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta, + DeviceMemory<Eigen::half> *c, int ldc, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + ProfileResult *output_profile_result) = 0; + + // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( std::vector<AlgorithmType> *out_algorithms) = 0; @@ -1473,6 +1548,28 @@ class BlasSupport { const DeviceMemory<std::complex<double>> &x, int incx, \ std::complex<double> beta, \ DeviceMemory<std::complex<double>> *y, int incy) override; \ + bool DoBlasGemvWithProfiling( \ + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \ + int incx, float beta, DeviceMemory<float> *y, int incy, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemvWithProfiling( \ + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \ + int incx, double beta, DeviceMemory<double> *y, int incy, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemvWithProfiling( \ + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ + std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ + int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ + int incy, blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemvWithProfiling( \ + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ + std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ + int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ + int incy, blas::ProfileResult *output_profile_result) override; \ bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ const DeviceMemory<float> &x, int incx, \ const DeviceMemory<float> &y, int incy, \ @@ -1751,6 +1848,39 @@ class BlasSupport { const DeviceMemory<std::complex<double>> &b, int ldb, \ std::complex<double> beta, \ DeviceMemory<std::complex<double>> *c, int ldc) override; \ + bool DoBlasGemmWithProfiling( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, \ + const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithProfiling( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, const DeviceMemory<float> &b, int ldb, float beta, \ + DeviceMemory<float> *c, int ldc, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithProfiling( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, double alpha, \ + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ + int ldb, double beta, DeviceMemory<double> *c, int ldc, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithProfiling( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, \ + const DeviceMemory<std::complex<float>> &b, int ldb, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithProfiling( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, \ + const DeviceMemory<std::complex<double>> &b, int ldb, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, blas::ProfileResult *output_profile_result) override; \ bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ override; \ bool DoBlasGemmWithAlgorithm( \ |