aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/blas.h
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-07-21 09:22:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 09:38:31 -0700
commit3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch)
treec7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/stream_executor/blas.h
parent4729180d24af3126d736a7045c43fcbf031b5bef (diff)
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--tensorflow/stream_executor/blas.h138
1 files changed, 134 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index cfff3649c8..eb1b19c5d9 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -44,7 +44,6 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
-#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
@@ -108,6 +107,10 @@ string ComputationTypeString(ComputationType ty);
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
+constexpr AlgorithmType kDefaultAlgorithm = -1;
+constexpr AlgorithmType kDefaultBlasGemm = -2;
+constexpr AlgorithmType kDefaultBlasGemv = -3;
+constexpr AlgorithmType kNoAlgorithm = -4;
// blas uses -1 to represent the default algorithm. This happens to match up
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
@@ -134,10 +137,28 @@ class ProfileResult {
private:
bool is_valid_ = false;
- AlgorithmType algorithm_ = 0;
+ AlgorithmType algorithm_ = kDefaultAlgorithm;
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
};
+class AlgorithmConfig {  // Value wrapper around a single AlgorithmType selection.
+ public:
+ AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}  // Default-constructed config selects kDefaultAlgorithm.
+ explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}  // explicit: no implicit conversion from a raw AlgorithmType.
+ AlgorithmType algorithm() const { return algorithm_; }  // Returns the stored algorithm id.
+ void set_algorithm(AlgorithmType val) { algorithm_ = val; }  // Overwrites the stored algorithm id.
+ bool operator==(const AlgorithmConfig &other) const {  // Two configs are equal iff their algorithm ids match.
+ return this->algorithm_ == other.algorithm_;
+ }
+ bool operator!=(const AlgorithmConfig &other) const {  // Defined as the negation of operator==.
+ return !(*this == other);
+ }
+ string ToString() const;  // Declared here only; the definition is not part of this header hunk.
+
+ private:
+ AlgorithmType algorithm_;  // The only state; always set by one of the constructors above.
+};
+
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -453,6 +474,29 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) = 0;
+ virtual bool DoBlasGemvWithProfiling(  // float overload; pure virtual, so every BlasSupport implementation must define it.
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
+ int incx, float beta, DeviceMemory<float> *y, int incy,
+ ProfileResult *output_profile_result) = 0;  // NOTE(review): presumably the out-param receives the profiling outcome -- confirm against the implementations.
+ virtual bool DoBlasGemvWithProfiling(  // double overload.
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
+ int incx, double beta, DeviceMemory<double> *y, int incy,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemvWithProfiling(  // std::complex<float> overload.
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
+ int lda, const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemvWithProfiling(  // std::complex<double> overload.
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
+ int lda, const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
+ int incy, ProfileResult *output_profile_result) = 0;
+
// Performs a rank-1 update of a general matrix.
//
// a <- alpha * x * y' + a,
@@ -935,8 +979,39 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
- // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
- // any or all of these algorithms may still be
+ virtual bool DoBlasGemmWithProfiling(  // Eigen::half matrices with float alpha/beta; pure virtual, so every BlasSupport implementation must define it.
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
+ DeviceMemory<Eigen::half> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;  // NOTE(review): presumably the out-param receives the profiling outcome -- confirm against the implementations.
+ virtual bool DoBlasGemmWithProfiling(  // float overload.
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(  // double overload.
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(  // std::complex<float> overload.
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(  // std::complex<double> overload.
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+
+ // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
virtual bool GetBlasGemmAlgorithms(
std::vector<AlgorithmType> *out_algorithms) = 0;
@@ -1473,6 +1548,28 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &x, int incx, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \
+ int incx, float beta, DeviceMemory<float> *y, int incy, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \
+ int incx, double beta, DeviceMemory<double> *y, int incy, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \
+ int lda, const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \
+ int incy, blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
+ int lda, const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \
+ int incy, blas::ProfileResult *output_profile_result) override; \
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
const DeviceMemory<float> &x, int incx, \
const DeviceMemory<float> &y, int incy, \
@@ -1751,6 +1848,39 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, \
+ const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
+ int ldb, double beta, DeviceMemory<double> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, blas::ProfileResult *output_profile_result) override; \
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
override; \
bool DoBlasGemmWithAlgorithm( \