aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/blas.h
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-03-02 17:49:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-02 18:08:01 -0800
commit01194694948eb883e99af597d9dbbf3fc9f5c9e2 (patch)
treeab3517cf656259681283a90c6682c5b320ac36e3 /tensorflow/stream_executor/blas.h
parente065b3093f4fec5a5f79ad9de81f6baab361962e (diff)
[XLA] [StreamExecutor] Tune GEMMs when possible.
cublas 8 adds the cublasGemmEx function, which lets you specify an explicit "algorithm" for the computation. This functions as an opaque tuning hint to cublas. This patch adds support for cublasGemmEx to StreamExecutor, and wires up XLA's GemmThunk to use the new function. This patch does not add GEMM autotuning support in TensorFlow proper, only XLA. Change: 149068961
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--tensorflow/stream_executor/blas.h145
1 files changed, 140 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 853bbfd1c7..07a0f7ccd6 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -88,6 +88,46 @@ enum class Side { kLeft, kRight };
// Returns a name for s.
string SideString(Side s);
+// Type with which intermediate computations of a blas routine are performed.
+//
+// Some blas calls can perform computations with a type that's different than
+// the type of their inputs/outputs. This lets you e.g. multiply two matricies
+// of int8s using float32s to store the matmul's intermediate values.
+enum class ComputationType {
+ kF16, // 16-bit floating-point
+ kF32, // 32-bit floating-point
+ kF64, // 64-bit floating-point
+ kComplexF32, // Complex number comprised of two f32s.
+ kComplexF64 // Complex number comprised of two f64s.
+};
+
+// Converts a ComputationType to a string.
+string ComputationTypeString(ComputationType ty);
+
+// Opaque identifier for an "algorithm" used by a blas routine. This functions
+// as a hint to the blas library.
+typedef int64 AlgorithmType;
+
+// Describes the result of a performance experiment, usually timing the speed of
+// a particular AlgorithmType.
+//
+// If the call we were benchmarking failed (a common occurrence; not all
+// algorithms are valid for all calls), is_valid() will be false.
+class ProfileResult {
+ public:
+ bool is_valid() const { return is_valid_; }
+ void set_is_valid(bool val) { is_valid_ = val; }
+ AlgorithmType algorithm() const { return algorithm_; }
+ void set_algorithm(AlgorithmType val) { algorithm_ = val; }
+ float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
+ void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
+
+ private:
+ bool is_valid_ = false;
+ AlgorithmType algorithm_ = 0;
+ float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
+};
+
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -856,11 +896,10 @@ class BlasSupport {
// batched version of the half-precision interface.
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
- float alpha,
- const DeviceMemory<Eigen::half> &a, int lda,
- const DeviceMemory<Eigen::half> &b, int ldb,
- float beta,
- DeviceMemory<Eigen::half> *c, int ldc) = 0;
+ float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+ float beta, DeviceMemory<Eigen::half> *c,
+ int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha, const DeviceMemory<float> &a, int lda,
@@ -886,6 +925,61 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+ // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
+ // any or all of these algorithms may still be
+ virtual bool GetBlasGemmAlgorithms(
+ std::vector<AlgorithmType> *out_algorithms) = 0;
+
+ // Like DoBlasGemm, but accepts an algorithm and an compute type.
+ //
+ // The compute type lets you say (e.g.) that the inputs and outputs are
+ // Eigen::halfs, but you want the internal computations to be done with
+ // float32 precision.
+ //
+ // Note the subtle difference in the version that accepts Eigen:::half --
+ // alpha and beta have type const Eigen::half&, not float.
+ //
+ // If output_profile_result is not null, a failure here does not put the
+ // stream in a failure state. Instead, success/failure is indicated by
+ // output_profile_result->is_valid(). This lets you use this function for
+ // choosing the best algorithm among many (some of which may fail) without
+ // creating a new Stream for each attempt.
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const Eigen::half &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
+ DeviceMemory<Eigen::half> *c, int ldc, ComputationType computation_type,
+ AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, ComputationType computation_type,
+ AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+
// Computes a batch of matrix-matrix product with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes matrix product for each input/output in a, b,
@@ -1641,6 +1735,47 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
+ override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, \
+ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc, blas::ComputationType computation_type, \
+ blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
+ int ldb, double beta, DeviceMemory<double> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, blas::ComputationType computation_type, \
+ blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \