aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/blas.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-18 19:36:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 19:40:33 -0700
commit491beb74cc5a08693d0b884d10532514ac1aef19 (patch)
tree6ef4b12d84de7c922816ee46c873b58a9fc5e203 /tensorflow/stream_executor/blas.h
parent9293c557bd2df05658727418067ccee7a77a4be3 (diff)
Automated g4 rollback of changelist 162423171
PiperOrigin-RevId: 162437318
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--tensorflow/stream_executor/blas.h138
1 files changed, 4 insertions, 134 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index eb1b19c5d9..cfff3649c8 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
@@ -107,10 +108,6 @@ string ComputationTypeString(ComputationType ty);
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
-constexpr AlgorithmType kDefaultAlgorithm = -1;
-constexpr AlgorithmType kDefaultBlasGemm = -2;
-constexpr AlgorithmType kDefaultBlasGemv = -3;
-constexpr AlgorithmType kNoAlgorithm = -4;
// blas uses -1 to represent the default algorithm. This happens to match up
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
@@ -137,28 +134,10 @@ class ProfileResult {
private:
bool is_valid_ = false;
- AlgorithmType algorithm_ = kDefaultAlgorithm;
+ AlgorithmType algorithm_ = 0;
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
};
-class AlgorithmConfig {
- public:
- AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
- explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
- AlgorithmType algorithm() const { return algorithm_; }
- void set_algorithm(AlgorithmType val) { algorithm_ = val; }
- bool operator==(const AlgorithmConfig &other) const {
- return this->algorithm_ == other.algorithm_;
- }
- bool operator!=(const AlgorithmConfig &other) const {
- return !(*this == other);
- }
- string ToString() const;
-
- private:
- AlgorithmType algorithm_;
-};
-
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -474,29 +453,6 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) = 0;
- virtual bool DoBlasGemvWithProfiling(
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
- const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
- int incx, float beta, DeviceMemory<float> *y, int incy,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemvWithProfiling(
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
- const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
- int incx, double beta, DeviceMemory<double> *y, int incy,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemvWithProfiling(
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
- std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
- int lda, const DeviceMemory<std::complex<float>> &x, int incx,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemvWithProfiling(
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
- std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
- int lda, const DeviceMemory<std::complex<double>> &x, int incx,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
- int incy, ProfileResult *output_profile_result) = 0;
-
// Performs a rank-1 update of a general matrix.
//
// a <- alpha * x * y' + a,
@@ -979,39 +935,8 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
- virtual bool DoBlasGemmWithProfiling(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
- int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
- DeviceMemory<Eigen::half> *c, int ldc,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemmWithProfiling(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
- int ldc, ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemmWithProfiling(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemmWithProfiling(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<float> alpha,
- const DeviceMemory<std::complex<float>> &a, int lda,
- const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
- ProfileResult *output_profile_result) = 0;
- virtual bool DoBlasGemmWithProfiling(
- Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<double> alpha,
- const DeviceMemory<std::complex<double>> &a, int lda,
- const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
- ProfileResult *output_profile_result) = 0;
-
- // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
+ // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
+ // any or all of these may still be unsupported for a call (TODO: confirm).
virtual bool GetBlasGemmAlgorithms(
std::vector<AlgorithmType> *out_algorithms) = 0;
@@ -1548,28 +1473,6 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &x, int incx, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *y, int incy) override; \
- bool DoBlasGemvWithProfiling( \
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \
- const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \
- int incx, float beta, DeviceMemory<float> *y, int incy, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemvWithProfiling( \
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
- const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \
- int incx, double beta, DeviceMemory<double> *y, int incy, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemvWithProfiling( \
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
- std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \
- int lda, const DeviceMemory<std::complex<float>> &x, int incx, \
- std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \
- int incy, blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemvWithProfiling( \
- Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
- std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
- int lda, const DeviceMemory<std::complex<double>> &x, int incx, \
- std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \
- int incy, blas::ProfileResult *output_profile_result) override; \
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
const DeviceMemory<float> &x, int incx, \
const DeviceMemory<float> &y, int incy, \
@@ -1848,39 +1751,6 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
- bool DoBlasGemmWithProfiling( \
- Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, float alpha, \
- const DeviceMemory<Eigen::half> &a, int lda, \
- const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
- DeviceMemory<Eigen::half> *c, int ldc, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemmWithProfiling( \
- Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
- int lda, const DeviceMemory<float> &b, int ldb, float beta, \
- DeviceMemory<float> *c, int ldc, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemmWithProfiling( \
- Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, double alpha, \
- const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
- int ldb, double beta, DeviceMemory<double> *c, int ldc, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemmWithProfiling( \
- Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
- const DeviceMemory<std::complex<float>> &a, int lda, \
- const DeviceMemory<std::complex<float>> &b, int ldb, \
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
- blas::ProfileResult *output_profile_result) override; \
- bool DoBlasGemmWithProfiling( \
- Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
- const DeviceMemory<std::complex<double>> &a, int lda, \
- const DeviceMemory<std::complex<double>> &b, int ldb, \
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
- int ldc, blas::ProfileResult *output_profile_result) override; \
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
override; \
bool DoBlasGemmWithAlgorithm( \