-rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.cc    | 231
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.h     |   8
-rw-r--r--  tensorflow/stream_executor/blas.cc                   |  17
-rw-r--r--  tensorflow/stream_executor/blas.h                    | 145
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc         | 231
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.h          |  36
-rw-r--r--  tensorflow/stream_executor/stream.cc                 | 163
-rw-r--r--  tensorflow/stream_executor/stream.h                  |  41
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc  |   9
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h   |   3
10 files changed, 828 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 46f316e844..a80f969b9d 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -45,57 +45,163 @@ struct MatrixDescriptor {
int64 num_cols;
};
-// Performs a gemm call on lhs_matrix and rhs_matrix and stores the result to
-// output_matrix.
+// Performs a gemm call without an explicit algorithm on lhs_matrix and
+// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
-tensorflow::Status DoGemm(MatrixDescriptor lhs_matrix,
- MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix, se::Stream* stream) {
+bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::Stream* stream) {
DCHECK(!output_matrix.transpose);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
- bool launch_ok =
- stream
- ->ThenBlasGemm(
- lhs_matrix.transpose ? se::blas::Transpose::kTranspose
- : se::blas::Transpose::kNoTranspose,
- rhs_matrix.transpose ? se::blas::Transpose::kTranspose
- : se::blas::Transpose::kNoTranspose,
- output_matrix.num_rows, output_matrix.num_cols,
- lhs_matrix.transpose
- ? lhs_matrix.num_rows
- : lhs_matrix.num_cols, // Size of the reduce dimension.
- /*alpha=*/1.0,
- lhs_data,
- lhs_matrix.num_rows, // The leading dimension of LHS.
- rhs_data,
- rhs_matrix.num_rows, // The leading dimension of RHS.
- /*beta=*/0.0, &output_data,
- output_matrix
- .num_rows) // The leading dimension of the output matrix.
- .ok();
- if (!launch_ok) {
- return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
+ auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose;
+ auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose;
+ auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
+
+ return stream
+ ->ThenBlasGemm(
+ lhs_transpose, rhs_transpose, output_matrix.num_rows,
+ output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0,
+ lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+ &output_data, /*leading dim of output=*/output_matrix.num_rows)
+ .ok();
+}
+
+// Like DoGemm, but takes an explicit computation type and algorithm.
+// computation_type specifies the type of intermediate values generated during
+// the matmul (e.g. your input/output matrices could be f16s but you could do
+// computations with f32s). algorithm is an opaque identifier which functions
+// as a hint to cublas.
+//
+// Not all algorithms are valid for all matrix sizes, and not all CUDA versions
+// and GPUs even support gemm-with-algorithm. So expect that this may fail
+// unless you've already checked that it works for this particular GPU + input
+// size.
+//
+// If you pass a non-null ProfileResult, this will always return true (assuming
+// the Stream was valid to begin with); check the is_valid property of the
+// ProfileResult to see whether the call actually succeeded.
+template <typename Element>
+bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
+ MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix,
+ se::blas::ComputationType computation_type,
+ se::blas::AlgorithmType algorithm, se::Stream* stream,
+ se::blas::ProfileResult* output_profile_result) {
+ DCHECK(!output_matrix.transpose);
+
+ se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
+ se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
+ se::DeviceMemory<Element> output_data(output_matrix.data);
+
+ auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose;
+ auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose;
+ auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
+
+ return stream
+ ->ThenBlasGemmWithAlgorithm(
+ lhs_transpose, rhs_transpose, output_matrix.num_rows,
+ output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0,
+ lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+ &output_data, /*leading dim of output=*/output_matrix.num_rows,
+ computation_type, algorithm, output_profile_result)
+ .ok();
+}
+
+// Experimentally tries to pick the best algorithm for the given gemm.
+//
+// This may fail under perfectly normal circumstances. In particular, it will
+// fail if the program was built with < CUDA 8 or if we're using a gpu older
+// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
+// all.
+template <typename Element>
+StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
+ MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::blas::ComputationType computation_type,
+ se::Stream* stream) {
+ std::vector<se::blas::AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
+
+ se::blas::ProfileResult best_result;
+ for (auto algorithm : algorithms) {
+ se::blas::ProfileResult profile_result;
+ // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
+ // for all algorithms if we're targeting < sm_50. But because we pass a
+ // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
+ // and the actual success-ness is returned in ProfileResult::is_valid.
+ DCHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
+ computation_type, algorithm, stream,
+ &profile_result));
+
+ if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
}
- return tensorflow::Status::OK();
+
+ if (best_result.is_valid()) {
+ return best_result.algorithm();
+ }
+
+ return InternalError(
+ "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms "
+ "ran successfully",
+ stream, algorithms.size());
}
-// Return, if the given type is a valid Gemm elemental type, the executor for
-// that type, else null.
-// TODO(b/27202055): consider more element types.
-std::function<tensorflow::Status(MatrixDescriptor, MatrixDescriptor,
- MatrixDescriptor, se::Stream*)>
-FindGemmExecutor(PrimitiveType type) {
+// Helper functions to go from a PrimitiveType to a templated version of
+// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune.
+auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
switch (type) {
case F32:
return &DoGemm<float>;
case F64:
return &DoGemm<double>;
default:
- return nullptr;
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+auto GetGemmWithAlgorithmFn(PrimitiveType type)
+ -> decltype(&DoGemmWithAlgorithm<float>) {
+ switch (type) {
+ case F32:
+ return &DoGemmWithAlgorithm<float>;
+ case F64:
+ return &DoGemmWithAlgorithm<double>;
+ default:
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
+ switch (type) {
+ case F32:
+ return &DoGemmAutotune<float>;
+ case F64:
+ return &DoGemmAutotune<double>;
+ default:
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+
+// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used
+// to specify the precision with which matmul computations should be performed,
+// separately from the precision of the inputs and result.
+se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
+ switch (type) {
+ case F32:
+ return se::blas::ComputationType::kF32;
+ case F64:
+ return se::blas::ComputationType::kF64;
+ default:
+ LOG(FATAL) << "Unsupported type.";
}
}
@@ -120,8 +226,6 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
tensorflow::Status GemmThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
VLOG(2) << "Executing a GemmThunk";
- auto executor = FindGemmExecutor(output_shape_.element_type());
- DCHECK(executor != nullptr);
se::DeviceMemoryBase lhs_data =
buffer_allocations.GetDeviceAddress(lhs_buffer_);
@@ -172,17 +276,66 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
make_descriptor(lhs_data, lhs_shape_, transpose_lhs_);
const MatrixDescriptor rhs_descriptor =
make_descriptor(rhs_data, rhs_shape_, transpose_rhs_);
+
+ // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
+ // autotune this gemm to figure out the best algorithm.
+ auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::Stream* stream) {
+ PrimitiveType element_type = output_shape_.element_type();
+ se::blas::ComputationType computation_type =
+ GetBlasComputationType(element_type);
+
+ const string& device_name = stream->parent()->GetDeviceDescription().name();
+ auto autotune_it = autotune_results_.find(device_name);
+ if (autotune_it == autotune_results_.end()) {
+ StatusOr<se::blas::AlgorithmType> best_algorithm =
+ GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
+ computation_type, stream);
+ autotune_it =
+ autotune_results_.insert({device_name, best_algorithm}).first;
+
+ if (autotune_it->second.ok()) {
+ VLOG(2) << "Autotune on GemmThunk " << this
+ << " successful; best algorithm is "
+ << best_algorithm.ValueOrDie();
+ } else {
+ VLOG(2) << "Autotune on GemmThunk " << this
+ << " unsuccessful. Will use generic gemm.";
+ }
+ }
+
+ const StatusOr<se::blas::AlgorithmType>& best_algorithm =
+ autotune_it->second;
+ if (best_algorithm.ok()) {
+ auto algorithm = best_algorithm.ValueOrDie();
+ VLOG(2) << "Using algorithm " << algorithm
+ << " chosen by autotuning on GemmThunk " << this;
+ return GetGemmWithAlgorithmFn(element_type)(
+ lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm,
+ stream,
+ /*output_profile_result=*/nullptr);
+ }
+ return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
+ stream);
+ };
+
+ bool launch_ok;
if (output_shape_.layout().minor_to_major(0) == 0) {
- return executor(
+ launch_ok = launch(
lhs_descriptor, rhs_descriptor,
MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
stream);
} else {
- return executor(
+ launch_ok = launch(
rhs_descriptor, lhs_descriptor,
MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
stream);
}
+
+ if (!launch_ok) {
+ return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
+ }
+ return tensorflow::Status::OK();
}
} // namespace gpu
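
The DoGemmAutotune and launch-lambda logic above boils down to: profile every candidate algorithm once, keep the fastest valid one, and cache the outcome per device name so later executions either reuse the chosen algorithm or fall back to the plain gemm. Below is a minimal standalone C++ sketch of that pattern; ProfileMs and the algorithm list are hypothetical stand-ins, not the real StreamExecutor calls.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

using AlgorithmType = int64_t;

// Stand-in for running one profiled gemm: returns elapsed ms, or nullopt if
// the algorithm is invalid for this call (mirrors ProfileResult::is_valid()).
std::optional<float> ProfileMs(AlgorithmType algo) {
  if (algo % 3 == 0) return std::nullopt;         // some algorithms just fail
  return 1.0f + 0.1f * static_cast<float>(algo);  // fake timing
}

// Picks the fastest valid algorithm, or nullopt if none worked, in which case
// the caller falls back to the plain gemm.
std::optional<AlgorithmType> Autotune(const std::vector<AlgorithmType>& algos) {
  std::optional<AlgorithmType> best;
  float best_ms = std::numeric_limits<float>::max();
  for (AlgorithmType algo : algos) {
    std::optional<float> ms = ProfileMs(algo);
    if (ms && *ms < best_ms) {
      best_ms = *ms;
      best = algo;
    }
  }
  return best;
}

int main() {
  // Cache keyed by device name, like GemmThunk::autotune_results_.
  std::unordered_map<std::string, std::optional<AlgorithmType>> cache;
  const std::string device = "sim-gpu-0";
  auto it = cache.find(device);
  if (it == cache.end()) {
    it = cache.emplace(device, Autotune({0, 1, 2, 3, 4})).first;
  }
  if (it->second) {
    std::cout << "using algorithm " << *it->second << "\n";
  } else {
    std::cout << "autotune failed; using generic gemm\n";
  }
}
```
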
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index b540da65b4..983cb87292 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -63,6 +63,14 @@ class GemmThunk : public Thunk {
const bool transpose_lhs_;
const bool transpose_rhs_;
+
+ // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
+ // results. The map's value is the best algorithm we've found for this thunk
+ // on this device, or an error if none of the algorithms worked and we should
+ // use the regular gemm without an algorithm.
+ std::unordered_map<string,
+ StatusOr<::perftools::gputools::blas::AlgorithmType>>
+ autotune_results_;
};
} // namespace gpu
diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index 239e3fce01..a59a1dda71 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -67,6 +67,23 @@ string SideString(Side s) {
}
}
+string ComputationTypeString(ComputationType ty) {
+ switch (ty) {
+ case ComputationType::kF16:
+ return "f16";
+ case ComputationType::kF32:
+ return "f32";
+ case ComputationType::kF64:
+ return "f64";
+ case ComputationType::kComplexF32:
+ return "complex f32";
+ case ComputationType::kComplexF64:
+ return "complex f64";
+ default:
+ LOG(FATAL) << "Unknown ComputationType " << static_cast<int32>(ty);
+ }
+}
+
} // namespace blas
} // namespace gputools
} // namespace perftools
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 853bbfd1c7..07a0f7ccd6 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -88,6 +88,46 @@ enum class Side { kLeft, kRight };
// Returns a name for s.
string SideString(Side s);
+// Type with which intermediate computations of a blas routine are performed.
+//
+// Some blas calls can perform computations with a type that's different than
+// the type of their inputs/outputs. This lets you e.g. multiply two matrices
+// of int8s using float32s to store the matmul's intermediate values.
+enum class ComputationType {
+ kF16, // 16-bit floating-point
+ kF32, // 32-bit floating-point
+ kF64, // 64-bit floating-point
+ kComplexF32, // Complex number comprised of two f32s.
+ kComplexF64 // Complex number comprised of two f64s.
+};
+
+// Converts a ComputationType to a string.
+string ComputationTypeString(ComputationType ty);
+
+// Opaque identifier for an "algorithm" used by a blas routine. This functions
+// as a hint to the blas library.
+typedef int64 AlgorithmType;
+
+// Describes the result of a performance experiment, usually timing the speed of
+// a particular AlgorithmType.
+//
+// If the call we were benchmarking failed (a common occurrence; not all
+// algorithms are valid for all calls), is_valid() will be false.
+class ProfileResult {
+ public:
+ bool is_valid() const { return is_valid_; }
+ void set_is_valid(bool val) { is_valid_ = val; }
+ AlgorithmType algorithm() const { return algorithm_; }
+ void set_algorithm(AlgorithmType val) { algorithm_ = val; }
+ float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
+ void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
+
+ private:
+ bool is_valid_ = false;
+ AlgorithmType algorithm_ = 0;
+ float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
+};
+
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -856,11 +896,10 @@ class BlasSupport {
// batched version of the half-precision interface.
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
- float alpha,
- const DeviceMemory<Eigen::half> &a, int lda,
- const DeviceMemory<Eigen::half> &b, int ldb,
- float beta,
- DeviceMemory<Eigen::half> *c, int ldc) = 0;
+ float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+ float beta, DeviceMemory<Eigen::half> *c,
+ int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha, const DeviceMemory<float> &a, int lda,
@@ -886,6 +925,61 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+  // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note
+  // that any or all of these algorithms may still fail for a particular call.
+ virtual bool GetBlasGemmAlgorithms(
+ std::vector<AlgorithmType> *out_algorithms) = 0;
+
+  // Like DoBlasGemm, but accepts an algorithm and a compute type.
+ //
+ // The compute type lets you say (e.g.) that the inputs and outputs are
+ // Eigen::halfs, but you want the internal computations to be done with
+ // float32 precision.
+ //
+  // Note the subtle difference in the version that accepts Eigen::half --
+ // alpha and beta have type const Eigen::half&, not float.
+ //
+ // If output_profile_result is not null, a failure here does not put the
+ // stream in a failure state. Instead, success/failure is indicated by
+ // output_profile_result->is_valid(). This lets you use this function for
+ // choosing the best algorithm among many (some of which may fail) without
+ // creating a new Stream for each attempt.
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const Eigen::half &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
+ DeviceMemory<Eigen::half> *c, int ldc, ComputationType computation_type,
+ AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, ComputationType computation_type,
+ AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+
// Computes a batch of matrix-matrix product with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes matrix product for each input/output in a, b,
@@ -1641,6 +1735,47 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
+ override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, \
+ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc, blas::ComputationType computation_type, \
+ blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
+ int ldb, double beta, DeviceMemory<double> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, blas::ComputationType computation_type, \
+ blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \
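
The new blas::ComputationType lets the intermediate precision differ from the operand type. As a toy illustration of why that matters (a sketch, not the StreamExecutor API): accumulating int8 products in a wider type gives a different answer than accumulating in int8, which wraps around; cuBLAS's compute type plays the same role for gemm.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int8_t> a(64, 100), b(64, 1);
  int8_t narrow = 0;  // accumulate in int8: wraps around long before the end
  int32_t wide = 0;   // accumulate in int32: exact for this input
  for (size_t i = 0; i < a.size(); ++i) {
    narrow = static_cast<int8_t>(narrow + a[i] * b[i]);
    wide += a[i] * b[i];
  }
  std::cout << "int8 accumulator:  " << static_cast<int>(narrow) << "\n"
            << "int32 accumulator: " << wide << "\n";  // prints 6400
}
```
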
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 3df366bc4d..2c650afc70 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
@@ -262,6 +263,10 @@ CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
#endif
+#if CUDA_VERSION >= 8000
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
@@ -282,6 +287,12 @@ static string ToString(cublasStatus_t status) {
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
+#if CUDA_VERSION >= 8000
+ case CUBLAS_STATUS_NOT_SUPPORTED:
+ return "CUBLAS_STATUS_NOT_SUPPORTED";
+ case CUBLAS_STATUS_LICENSE_ERROR:
+ return "CUBLAS_STATUS_LICENSE_ERROR";
+#endif
default:
return port::StrCat("<invalid cublas status: ", status, ">");
}
@@ -431,11 +442,89 @@ cublasSideMode_t CUDABlasSide(blas::Side side) {
}
}
+// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
+// cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a
+// blas::ComputationType to a cudaDataType_t.
+//
+// These are used to build the argument type and computation type args to
+// cublasGemmEx. cublasGemmEx and cudaDataType_t are available only on
+// CUDA >= 8.0.
+#if CUDA_VERSION >= 8000
+template <typename T>
+struct CUDADataType;
+
+template <>
+struct CUDADataType<Eigen::half> {
+ static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
+};
+
+template <>
+struct CUDADataType<std::complex<Eigen::half>> {
+ static constexpr cudaDataType_t type = CUDA_C_16F;
+};
+
+template <>
+struct CUDADataType<float> {
+ static constexpr cudaDataType_t type = CUDA_R_32F;
+};
+
+template <>
+struct CUDADataType<std::complex<float>> {
+ static constexpr cudaDataType_t type = CUDA_C_32F;
+};
+
+template <>
+struct CUDADataType<double> {
+ static constexpr cudaDataType_t type = CUDA_R_64F;
+};
+
+template <>
+struct CUDADataType<std::complex<double>> {
+ static constexpr cudaDataType_t type = CUDA_C_64F;
+};
+
+template <>
+struct CUDADataType<int8> {
+ static constexpr cudaDataType_t type = CUDA_R_8I;
+};
+
+template <>
+struct CUDADataType<std::complex<int8>> {
+ static constexpr cudaDataType_t type = CUDA_C_8I;
+};
+
+template <>
+struct CUDADataType<uint8> {
+ static constexpr cudaDataType_t type = CUDA_R_8U;
+};
+
+template <>
+struct CUDADataType<std::complex<uint8>> {
+ static constexpr cudaDataType_t type = CUDA_C_8U;
+};
+
+cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
+ switch (ty) {
+ case blas::ComputationType::kF16:
+ return CUDA_R_16F;
+ case blas::ComputationType::kF32:
+ return CUDA_R_32F;
+ case blas::ComputationType::kF64:
+ return CUDA_R_64F;
+ case blas::ComputationType::kComplexF32:
+ return CUDA_C_32F;
+ case blas::ComputationType::kComplexF64:
+ return CUDA_C_64F;
+ }
+}
+#endif
+
} // namespace
template <typename FuncT, typename... Args>
-bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream,
- bool pointer_mode_host, Args... args) {
+bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
+ bool pointer_mode_host, bool err_on_failure,
+ Args... args) {
mutex_lock lock{mu_};
CHECK(blas_ != nullptr);
@@ -450,13 +539,11 @@ bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream,
}
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
- if (ret != CUBLAS_STATUS_SUCCESS) {
+ if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
<< ToString(ret);
- return false;
}
-
- return true;
+ return ret == CUBLAS_STATUS_SUCCESS;
}
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
@@ -1762,6 +1849,138 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
CUDAComplex(CUDAMemoryMutable(c)), ldc);
}
+template <typename T>
+bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
+ const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
+#if CUDA_VERSION < 8000
+ return false;
+#else
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor) &&
+ cc_major < 5) {
+ return false;
+ }
+
+ struct TimerDeleter {
+ void operator()(CUDATimer *t) {
+ t->Destroy();
+ delete t;
+ }
+ };
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
+ if (output_profile_result != nullptr) {
+ timer.reset(new CUDATimer(parent_));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return false;
+ }
+ }
+
+ cudaDataType_t data_type = CUDADataType<T>::type;
+ bool result = DoBlasInternalFailureOK(
+ wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
+ CUDAMemoryMutable(c), data_type, ldc,
+ CUDAComputationType(computation_type),
+ static_cast<cublasGemmAlgo_t>(algorithm));
+
+ if (timer != nullptr && result) {
+ // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
+ // state.
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return false;
+ }
+ output_profile_result->set_is_valid(true);
+ output_profile_result->set_algorithm(algorithm);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ }
+ return result;
+#endif
+}
+
+bool CUDABlas::GetBlasGemmAlgorithms(
+ std::vector<blas::AlgorithmType> *out_algorithms) {
+// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+// were first introduced in CUDA 8.
+#if CUDA_VERSION >= 8000
+ for (cublasGemmAlgo_t algo :
+ {CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
+ out_algorithms->push_back(algo);
+ }
+#endif
+ return true;
+}
+
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const Eigen::half &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
+ DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithAlgorithmImpl(
+ stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ computation_type, algorithm, output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithAlgorithmImpl(
+ stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ computation_type, algorithm, output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithAlgorithmImpl(
+ stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ computation_type, algorithm, output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithAlgorithmImpl(
+ stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ computation_type, algorithm, output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithAlgorithmImpl(
+ stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ computation_type, algorithm, output_profile_result);
+}
+
template <typename T, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
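
For reference, the profiling path added in DoBlasGemmWithAlgorithmImpl corresponds roughly to the following direct cuBLAS usage: call cublasGemmEx once per candidate cublasGemmAlgo_t and time each call with CUDA events. This is a hedged sketch assuming the CUDA 8-10 signature (CUDA 11 later changed the compute-type parameter to cublasComputeType_t); error handling is minimal and the buffers are left uninitialized since only timing matters here.

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

int main() {
  const int n = 512;
  float *a, *b, *c;
  cudaMalloc(reinterpret_cast<void**>(&a), n * n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&b), n * n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&c), n * n * sizeof(float));

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  const std::vector<cublasGemmAlgo_t> algos = {
      CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
      CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3};

  for (cublasGemmAlgo_t algo : algos) {
    cudaEventRecord(start);
    cublasStatus_t status = cublasGemmEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha,
        a, CUDA_R_32F, n, b, CUDA_R_32F, n, &beta,
        c, CUDA_R_32F, n, CUDA_R_32F, algo);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    if (status != CUBLAS_STATUS_SUCCESS) {
      // Expected for some algorithm/size combinations; this is what
      // ProfileResult::is_valid() reports in the StreamExecutor path.
      std::printf("algo %d: unsupported (status %d)\n", int(algo), int(status));
      continue;
    }
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    std::printf("algo %d: %.3f ms\n", int(algo), ms);
  }

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cublasDestroy(handle);
  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
}
```
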
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 1226df6d65..6a33cd746b 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -79,10 +79,27 @@ class CUDABlas : public blas::BlasSupport {
// stream: Stream to enqueue the BLAS operation onto.
// pointer_mode_host: Indicate if the pointer to a scalar value is from host
// (true) or device (false).
+ // err_on_failure: Whether to print an error if the cublas function fails.
// args: Arguments of cuBLAS function.
template <typename FuncT, typename... Args>
+ bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
+ bool pointer_mode_host, bool err_on_failure,
+ Args... args);
+
+ // Convenience functions that call DoBlasInternalImpl with different values
+ // for err_on_failure.
+ template <typename FuncT, typename... Args>
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
- Args... args);
+ Args... args) {
+ return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
+ /*err_on_failure=*/true, args...);
+ }
+ template <typename FuncT, typename... Args>
+ bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
+ bool pointer_mode_host, Args... args) {
+ return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
+ /*err_on_failure=*/false, args...);
+ }
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.
@@ -95,6 +112,23 @@ class CUDABlas : public blas::BlasSupport {
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
+ // Helper function for implementing DoBlasGemmWithAlgorithm.
+ //
+ // We take alpha and beta by const reference because T might be Eigen::half,
+ // and we want to avoid pulling in a dependency on Eigen. When we pass the
+ // references to cublas, we essentially reinterpret_cast to __half, which is
+ // safe because Eigen::half inherits from __half.
+ template <typename T>
+ bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const T &alpha,
+ const DeviceMemory<T> &a, int lda,
+ const DeviceMemory<T> &b, int ldb,
+ const T &beta, DeviceMemory<T> *c, int ldc,
+ blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+
// mutex that guards the cuBLAS handle for this device.
mutex mu_;
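
The DoBlasInternalImpl change is the usual "one implementation, two named wrappers" refactor: the flag-taking impl does the work, while DoBlasInternal and DoBlasInternalFailureOK just pin err_on_failure. A minimal sketch of that shape follows; RunImpl, Run, and RunFailureOK are illustrative names, not the real API.

```cpp
#include <iostream>
#include <utility>

enum class Status { kOk, kFailed };

// The flag-taking implementation: runs the routine and optionally logs.
template <typename FuncT, typename... Args>
bool RunImpl(FuncT func, bool err_on_failure, Args&&... args) {
  Status s = func(std::forward<Args>(args)...);
  if (err_on_failure && s != Status::kOk) {
    std::cerr << "routine failed\n";
  }
  return s == Status::kOk;
}

// Convenience wrapper: logs on failure.
template <typename FuncT, typename... Args>
bool Run(FuncT func, Args&&... args) {
  return RunImpl(func, /*err_on_failure=*/true, std::forward<Args>(args)...);
}

// Convenience wrapper: stays silent on failure.
template <typename FuncT, typename... Args>
bool RunFailureOK(FuncT func, Args&&... args) {
  return RunImpl(func, /*err_on_failure=*/false, std::forward<Args>(args)...);
}

int main() {
  auto always_fails = [](int) { return Status::kFailed; };
  Run(always_fails, 42);           // prints "routine failed", returns false
  RunFailureOK(always_fails, 42);  // silent, returns false
}
```
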
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9b67531689..f28f965f2c 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -75,6 +75,10 @@ string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
string ToVlogString(blas::Side s) { return blas::SideString(s); }
+string ToVlogString(blas::ComputationType ty) {
+ return blas::ComputationTypeString(ty);
+}
+
string ToVlogString(const void *ptr) {
if (ptr == nullptr) {
return "null";
@@ -109,6 +113,8 @@ string ToVlogString(const DeviceMemoryBase *memory) {
return ToVlogString(*memory);
}
+string ToVlogString(const Eigen::half &h) { return port::StrCat(h); }
+
string ToVlogString(int i) { return port::StrCat(i); }
string ToVlogString(uint32 i) { return port::StrCat(i); }
@@ -1520,21 +1526,33 @@ struct ThenBlasImpl {
// arguments except the first one of Stream* type.
Stream &operator()(Stream *stream,
bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
- Args... args);
+ Args... args) {
+ return Run(stream, blas_func, /*record_error=*/true, args...);
+ }
+
+ // Like operator(), but only calls stream->CheckError() if record_error is
+ // true.
+ Stream &Run(Stream *stream,
+ bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
+ bool record_error, Args... args);
};
template <typename... Args>
-Stream &ThenBlasImpl<Args...>::operator()(
+Stream &ThenBlasImpl<Args...>::Run(
Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
- Args... args) {
+ bool record_error, Args... args) {
if (stream->ok()) {
+ bool ok;
if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
- stream->CheckError((blas->*blas_func)(stream, args...));
+ ok = (blas->*blas_func)(stream, args...);
} else {
- stream->CheckError(false);
LOG(WARNING)
<< "attempting to perform BLAS operation using StreamExecutor "
"without BLAS support";
+ ok = false;
+ }
+ if (record_error) {
+ stream->CheckError(ok);
}
}
return *stream;
@@ -3215,6 +3233,141 @@ Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
alpha, a, lda, b, ldb, beta, c, ldc);
}
+namespace {
+// Like ThenBlasImpl, except this expects the last argument of blas_func to be a
+// blas::ProfileResult*. This functor doesn't put the stream into an error
+// state if the op fails and the profile result is non-null. Instead, the
+// error-ness is returned in the profile result itself.
+template <typename... Args>
+struct ThenBlasWithProfileImpl {
+ Stream &operator()(Stream *stream,
+ bool (blas::BlasSupport::*blas_func)(
+ Stream *, Args..., blas::ProfileResult *),
+ Args... args, blas::ProfileResult *profile_result) {
+ ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
+ bool record_error = profile_result == nullptr;
+ return Runner.Run(stream, blas_func, record_error, args..., profile_result);
+ }
+};
+} // anonymous namespace
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+ const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, const Eigen::half &,
+ const DeviceMemory<Eigen::half> &, int,
+ const DeviceMemory<Eigen::half> &, int,
+ const Eigen::half &, DeviceMemory<Eigen::half> *, int,
+ blas::ComputationType, blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
+ DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, double, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
+ DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
uint64 n, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a,
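
The stream.cc change threads a record_error flag through ThenBlasImpl so that profiled calls can fail without poisoning the stream. A small self-contained sketch of that gating is below; FakeStream and FakeBlas are stand-ins for the real classes, not the StreamExecutor API.

```cpp
#include <iostream>

struct ProfileResult {
  bool is_valid = false;
};

struct FakeBlas {
  // Pretend the requested algorithm is unsupported: the call "fails", and if
  // the caller asked for profiling the failure is reported via the result.
  bool Gemm(int /*m*/, int /*n*/, ProfileResult* result) {
    if (result != nullptr) result->is_valid = false;
    return false;
  }
};

struct FakeStream {
  FakeBlas blas;
  bool ok_ = true;
  void CheckError(bool ok) { ok_ = ok_ && ok; }
};

// Mirrors ThenBlasImpl: invokes a BLAS member function and, if asked, records
// the result on the stream.
template <typename... Args>
struct ThenImpl {
  FakeStream& Run(FakeStream* stream, bool (FakeBlas::*func)(Args...),
                  bool record_error, Args... args) {
    bool ok = (stream->blas.*func)(args...);
    if (record_error) stream->CheckError(ok);
    return *stream;
  }
};

// Mirrors ThenBlasWithProfileImpl: only records the error when no profile
// result was requested, so a failed profiling run doesn't poison the stream.
template <typename... Args>
struct ThenWithProfileImpl {
  FakeStream& operator()(FakeStream* stream,
                         bool (FakeBlas::*func)(Args..., ProfileResult*),
                         Args... args, ProfileResult* profile_result) {
    ThenImpl<Args..., ProfileResult*> runner;
    return runner.Run(stream, func, /*record_error=*/profile_result == nullptr,
                      args..., profile_result);
  }
};

int main() {
  FakeStream stream;
  ProfileResult profile;
  ThenWithProfileImpl<int, int> impl;
  impl(&stream, &FakeBlas::Gemm, 4, 4, &profile);  // failure stays in profile
  std::cout << "stream ok: " << stream.ok_
            << ", profile valid: " << profile.is_valid << "\n";
}
```
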
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index a092e87004..f22fba1d74 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1179,6 +1179,47 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc);
+ // See BlasSupport::DoBlasGemmWithAlgorithm.
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+ const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc,
+ blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+
// See BlasSupport::DoBlasGemmBatched.
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, float alpha,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index c498eecb3c..42fcd5867c 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -310,6 +310,15 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
}
+bool StreamExecutor::GetBlasGemmAlgorithms(
+ std::vector<blas::AlgorithmType> *out_algorithms) {
+ blas::BlasSupport *blas_support = AsBlas();
+ if (!blas_support) {
+ return false;
+ }
+ return blas_support->GetBlasGemmAlgorithms(out_algorithms);
+}
+
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 29ba63af05..5c52afa794 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -353,6 +353,9 @@ class StreamExecutor {
bool GetConvolveBackwardFilterAlgorithms(
std::vector<dnn::AlgorithmType> *out_algorithms);
+ // Get the list of supported algorithms for BLAS gemm.
+ bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
+
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(