diff options
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 231 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gemm_thunk.h | 8 | ||||
-rw-r--r-- | tensorflow/stream_executor/blas.cc | 17 | ||||
-rw-r--r-- | tensorflow/stream_executor/blas.h | 145 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 231 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.h | 36 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 163 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream.h | 41 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 9 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 3 |
10 files changed, 828 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 46f316e844..a80f969b9d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -45,57 +45,163 @@ struct MatrixDescriptor { int64 num_cols; }; -// Performs a gemm call on lhs_matrix and rhs_matrix and stores the result to -// output_matrix. +// Performs a gemm call without an explicit algorithm on lhs_matrix and +// rhs_matrix, and stores the result to output_matrix. template <typename Element> -tensorflow::Status DoGemm(MatrixDescriptor lhs_matrix, - MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { +bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { DCHECK(!output_matrix.transpose); se::DeviceMemory<Element> lhs_data(lhs_matrix.data); se::DeviceMemory<Element> rhs_data(rhs_matrix.data); se::DeviceMemory<Element> output_data(output_matrix.data); - bool launch_ok = - stream - ->ThenBlasGemm( - lhs_matrix.transpose ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose, - rhs_matrix.transpose ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose, - output_matrix.num_rows, output_matrix.num_cols, - lhs_matrix.transpose - ? lhs_matrix.num_rows - : lhs_matrix.num_cols, // Size of the reduce dimension. - /*alpha=*/1.0, - lhs_data, - lhs_matrix.num_rows, // The leading dimension of LHS. - rhs_data, - rhs_matrix.num_rows, // The leading dimension of RHS. - /*beta=*/0.0, &output_data, - output_matrix - .num_rows) // The leading dimension of the output matrix. - .ok(); - if (!launch_ok) { - return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); + auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); +} + +// Like DoGemm, but takes an explicit computation type and algorithm. +// computation_type specifies the type of intermediate values generated during +// the matmul (e.g. your input/output matricies could be f16s but you could do +// computations with f32s). algorithm is an opaque identifier which functions +// as a hint to cublas. +// +// Not all algorithms are valid for all matrix sizes, and not all CUDA versions +// and GPUs even support gemm-with-algorithm. So expect that this may fail +// unless you've already checked that it works for this particular GPU + input +// size. +// +// If you pass a non-null ProfileResult, this will always return true (assuming +// the Stream was valid to begin with); check the is_valid property of the +// ProfileResult to see whether the call actually succeeded. +template <typename Element> +bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, + MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, + se::blas::ComputationType computation_type, + se::blas::AlgorithmType algorithm, se::Stream* stream, + se::blas::ProfileResult* output_profile_result) { + DCHECK(!output_matrix.transpose); + + se::DeviceMemory<Element> lhs_data(lhs_matrix.data); + se::DeviceMemory<Element> rhs_data(rhs_matrix.data); + se::DeviceMemory<Element> output_data(output_matrix.data); + + auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + + return stream + ->ThenBlasGemmWithAlgorithm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows, + computation_type, algorithm, output_profile_result) + .ok(); +} + +// Experimentally tries to pick the best algorithm for the given gemm. +// +// This may fail under perfectly normal circumstances. In particular, it will +// fail if the program was built with < CUDA 8 or if we're using a gpu older +// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at +// all. +template <typename Element> +StatusOr<se::blas::AlgorithmType> DoGemmAutotune( + MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::blas::ComputationType computation_type, + se::Stream* stream) { + std::vector<se::blas::AlgorithmType> algorithms; + CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); + + se::blas::ProfileResult best_result; + for (auto algorithm : algorithms) { + se::blas::ProfileResult profile_result; + // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail + // for all algorithms if we're targeting < sm_50. But because we pass a + // non-null ProfileResult, DoGemmWithAlgorithm should always return true, + // and the actual success-ness is returned in ProfileResult::is_valid. + DCHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix, + computation_type, algorithm, stream, + &profile_result)); + + if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } } - return tensorflow::Status::OK(); + + if (best_result.is_valid()) { + return best_result.algorithm(); + } + + return InternalError( + "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "ran successfully", + stream, algorithms.size()); } -// Return, if the given type is a valid Gemm elemental type, the executor for -// that type, else null. -// TODO(b/27202055): consider more element types. -std::function<tensorflow::Status(MatrixDescriptor, MatrixDescriptor, - MatrixDescriptor, se::Stream*)> -FindGemmExecutor(PrimitiveType type) { +// Helper functions to go from a PrimitiveType to a templated version of +// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. +auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) { switch (type) { case F32: return &DoGemm<float>; case F64: return &DoGemm<double>; default: - return nullptr; + LOG(FATAL) << "Unsupported type."; + } +} +auto GetGemmWithAlgorithmFn(PrimitiveType type) + -> decltype(&DoGemmWithAlgorithm<float>) { + switch (type) { + case F32: + return &DoGemmWithAlgorithm<float>; + case F64: + return &DoGemmWithAlgorithm<double>; + default: + LOG(FATAL) << "Unsupported type."; + } +} +auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) { + switch (type) { + case F32: + return &DoGemmAutotune<float>; + case F64: + return &DoGemmAutotune<double>; + default: + LOG(FATAL) << "Unsupported type."; + } +} + +// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used +// to specify the precision with which matmul computations should be performed, +// separately from the precision of the inputs and result. +se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { + switch (type) { + case F32: + return se::blas::ComputationType::kF32; + case F64: + return se::blas::ComputationType::kF64; + default: + LOG(FATAL) << "Unsupported type."; } } @@ -120,8 +226,6 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, tensorflow::Status GemmThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; - auto executor = FindGemmExecutor(output_shape_.element_type()); - DCHECK(executor != nullptr); se::DeviceMemoryBase lhs_data = buffer_allocations.GetDeviceAddress(lhs_buffer_); @@ -172,17 +276,66 @@ tensorflow::Status GemmThunk::ExecuteOnStream( make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); const MatrixDescriptor rhs_descriptor = make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + + // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to + // autotune this gemm to figure out the best algorithm. + auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { + PrimitiveType element_type = output_shape_.element_type(); + se::blas::ComputationType computation_type = + GetBlasComputationType(element_type); + + const string& device_name = stream->parent()->GetDeviceDescription().name(); + auto autotune_it = autotune_results_.find(device_name); + if (autotune_it == autotune_results_.end()) { + StatusOr<se::blas::AlgorithmType> best_algorithm = + GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + computation_type, stream); + autotune_it = + autotune_results_.insert({device_name, best_algorithm}).first; + + if (autotune_it->second.ok()) { + VLOG(2) << "Autotune on GemmThunk " << this + << " successful; best algorithm is " + << best_algorithm.ValueOrDie(); + } else { + VLOG(2) << "Autotune on GemmThunk " << this + << " unsuccessful. Will use generic gemm."; + } + } + + const StatusOr<se::blas::AlgorithmType>& best_algorithm = + autotune_it->second; + if (best_algorithm.ok()) { + auto algorithm = best_algorithm.ValueOrDie(); + VLOG(2) << "Using algorithm " << algorithm + << " chosen by autotuning on GemmThunk " << this; + return GetGemmWithAlgorithmFn(element_type)( + lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm, + stream, + /*output_profile_result=*/nullptr); + } + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + stream); + }; + + bool launch_ok; if (output_shape_.layout().minor_to_major(0) == 0) { - return executor( + launch_ok = launch( lhs_descriptor, rhs_descriptor, MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), stream); } else { - return executor( + launch_ok = launch( rhs_descriptor, lhs_descriptor, MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), stream); } + + if (!launch_ok) { + return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); + } + return tensorflow::Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index b540da65b4..983cb87292 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -63,6 +63,14 @@ class GemmThunk : public Thunk { const bool transpose_lhs_; const bool transpose_rhs_; + + // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune + // results. The map's value is the best algorithm we've found for this thunk + // on this device, or an error if none of the algorithms worked and we should + // use the regular gemm without an algorithm. + std::unordered_map<string, + StatusOr<::perftools::gputools::blas::AlgorithmType>> + autotune_results_; }; } // namespace gpu diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index 239e3fce01..a59a1dda71 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -67,6 +67,23 @@ string SideString(Side s) { } } +string ComputationTypeString(ComputationType ty) { + switch (ty) { + case ComputationType::kF16: + return "f16"; + case ComputationType::kF32: + return "f32"; + case ComputationType::kF64: + return "f64"; + case ComputationType::kComplexF32: + return "complex f32"; + case ComputationType::kComplexF64: + return "complex f64"; + default: + LOG(FATAL) << "Unknown ComputationType " << static_cast<int32>(ty); + } +} + } // namespace blas } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 853bbfd1c7..07a0f7ccd6 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -88,6 +88,46 @@ enum class Side { kLeft, kRight }; // Returns a name for s. string SideString(Side s); +// Type with which intermediate computations of a blas routine are performed. +// +// Some blas calls can perform computations with a type that's different than +// the type of their inputs/outputs. This lets you e.g. multiply two matricies +// of int8s using float32s to store the matmul's intermediate values. +enum class ComputationType { + kF16, // 16-bit floating-point + kF32, // 32-bit floating-point + kF64, // 64-bit floating-point + kComplexF32, // Complex number comprised of two f32s. + kComplexF64 // Complex number comprised of two f64s. +}; + +// Converts a ComputationType to a string. +string ComputationTypeString(ComputationType ty); + +// Opaque identifier for an "algorithm" used by a blas routine. This functions +// as a hint to the blas library. +typedef int64 AlgorithmType; + +// Describes the result of a performance experiment, usually timing the speed of +// a particular AlgorithmType. +// +// If the call we were benchmarking failed (a common occurrence; not all +// algorithms are valid for all calls), is_valid() will be false. +class ProfileResult { + public: + bool is_valid() const { return is_valid_; } + void set_is_valid(bool val) { is_valid_ = val; } + AlgorithmType algorithm() const { return algorithm_; } + void set_algorithm(AlgorithmType val) { algorithm_ = val; } + float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } + void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } + + private: + bool is_valid_ = false; + AlgorithmType algorithm_ = 0; + float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); +}; + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -856,11 +896,10 @@ class BlasSupport { // batched version of the half-precision interface. virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, - float alpha, - const DeviceMemory<Eigen::half> &a, int lda, - const DeviceMemory<Eigen::half> &b, int ldb, - float beta, - DeviceMemory<Eigen::half> *c, int ldc) = 0; + float alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, + float beta, DeviceMemory<Eigen::half> *c, + int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, @@ -886,6 +925,61 @@ class BlasSupport { std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc) = 0; + // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that + // any or all of these algorithms may still be + virtual bool GetBlasGemmAlgorithms( + std::vector<AlgorithmType> *out_algorithms) = 0; + + // Like DoBlasGemm, but accepts an algorithm and an compute type. + // + // The compute type lets you say (e.g.) that the inputs and outputs are + // Eigen::halfs, but you want the internal computations to be done with + // float32 precision. + // + // Note the subtle difference in the version that accepts Eigen:::half -- + // alpha and beta have type const Eigen::half&, not float. + // + // If output_profile_result is not null, a failure here does not put the + // stream in a failure state. Instead, success/failure is indicated by + // output_profile_result->is_valid(). This lets you use this function for + // choosing the best algorithm among many (some of which may fail) without + // creating a new Stream for each attempt. + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const Eigen::half &alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, + DeviceMemory<Eigen::half> *c, int ldc, ComputationType computation_type, + AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, ComputationType computation_type, AlgorithmType algorithm, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, ComputationType computation_type, + AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + ComputationType computation_type, AlgorithmType algorithm, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + ComputationType computation_type, AlgorithmType algorithm, + ProfileResult *output_profile_result) = 0; + // Computes a batch of matrix-matrix product with general matrices. // This is a batched version of DoBlasGemm. // The batched GEMM computes matrix product for each input/output in a, b, @@ -1641,6 +1735,47 @@ class BlasSupport { const DeviceMemory<std::complex<double>> &b, int ldb, \ std::complex<double> beta, \ DeviceMemory<std::complex<double>> *c, int ldc) override; \ + bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ + override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, \ + const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \ + DeviceMemory<Eigen::half> *c, int ldc, \ + blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, const DeviceMemory<float> &b, int ldb, float beta, \ + DeviceMemory<float> *c, int ldc, blas::ComputationType computation_type, \ + blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, double alpha, \ + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ + int ldb, double beta, DeviceMemory<double> *c, int ldc, \ + blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, \ + const DeviceMemory<std::complex<float>> &b, int ldb, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, \ + const DeviceMemory<std::complex<double>> &b, int ldb, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, blas::ComputationType computation_type, \ + blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 3df366bc4d..2c650afc70 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_helpers.h" #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/stream_executor/cuda/cuda_stream.h" +#include "tensorflow/stream_executor/cuda/cuda_timer.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" @@ -262,6 +263,10 @@ CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP) PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx) #endif +#if CUDA_VERSION >= 8000 +PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx) +#endif + } // namespace wrap static string ToString(cublasStatus_t status) { @@ -282,6 +287,12 @@ static string ToString(cublasStatus_t status) { return "CUBLAS_STATUS_EXECUTION_FAILED"; case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; +#if CUDA_VERSION >= 8000 + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; +#endif default: return port::StrCat("<invalid cublas status: ", status, ">"); } @@ -431,11 +442,89 @@ cublasSideMode_t CUDABlasSide(blas::Side side) { } } +// CUDADataType<T>::type translates from a C++ type (e.g. float) to a +// cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a +// blas::ComputationType to a cudaDataType_t. +// +// These are used to build the argument type and computation type args to +// cublasGemmEx. cublasGemmEx and cudaDataType_t are available only on +// CUDA >= 8.0. +#if CUDA_VERSION >= 8000 +template <typename T> +struct CUDADataType; + +template <> +struct CUDADataType<Eigen::half> { + static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF; +}; + +template <> +struct CUDADataType<std::complex<Eigen::half>> { + static constexpr cudaDataType_t type = CUDA_C_16F; +}; + +template <> +struct CUDADataType<float> { + static constexpr cudaDataType_t type = CUDA_R_32F; +}; + +template <> +struct CUDADataType<std::complex<float>> { + static constexpr cudaDataType_t type = CUDA_C_32F; +}; + +template <> +struct CUDADataType<double> { + static constexpr cudaDataType_t type = CUDA_R_64F; +}; + +template <> +struct CUDADataType<std::complex<double>> { + static constexpr cudaDataType_t type = CUDA_C_64F; +}; + +template <> +struct CUDADataType<int8> { + static constexpr cudaDataType_t type = CUDA_R_8I; +}; + +template <> +struct CUDADataType<std::complex<int8>> { + static constexpr cudaDataType_t type = CUDA_C_8I; +}; + +template <> +struct CUDADataType<uint8> { + static constexpr cudaDataType_t type = CUDA_R_8U; +}; + +template <> +struct CUDADataType<std::complex<uint8>> { + static constexpr cudaDataType_t type = CUDA_C_8U; +}; + +cudaDataType_t CUDAComputationType(blas::ComputationType ty) { + switch (ty) { + case blas::ComputationType::kF16: + return CUDA_R_16F; + case blas::ComputationType::kF32: + return CUDA_R_32F; + case blas::ComputationType::kF64: + return CUDA_R_64F; + case blas::ComputationType::kComplexF32: + return CUDA_C_32F; + case blas::ComputationType::kComplexF64: + return CUDA_C_64F; + } +} +#endif + } // namespace template <typename FuncT, typename... Args> -bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream, - bool pointer_mode_host, Args... args) { +bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args... args) { mutex_lock lock{mu_}; CHECK(blas_ != nullptr); @@ -450,13 +539,11 @@ bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream, } cublasStatus_t ret = cublas_func(parent_, blas_, args...); - if (ret != CUBLAS_STATUS_SUCCESS) { + if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": " << ToString(ret); - return false; } - - return true; + return ret == CUBLAS_STATUS_SUCCESS; } bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count, @@ -1762,6 +1849,138 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, CUDAComplex(CUDAMemoryMutable(c)), ldc); } +template <typename T> +bool CUDABlas::DoBlasGemmWithAlgorithmImpl( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda, + const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c, + int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { +// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx. +#if CUDA_VERSION < 8000 + return false; +#else + int cc_major, cc_minor; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor) && + cc_major < 5) { + return false; + } + + struct TimerDeleter { + void operator()(CUDATimer *t) { + t->Destroy(); + delete t; + } + }; + std::unique_ptr<CUDATimer, TimerDeleter> timer; + if (output_profile_result != nullptr) { + timer.reset(new CUDATimer(parent_)); + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return false; + } + } + + cudaDataType_t data_type = CUDADataType<T>::type; + bool result = DoBlasInternalFailureOK( + wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta, + CUDAMemoryMutable(c), data_type, ldc, + CUDAComputationType(computation_type), + static_cast<cublasGemmAlgo_t>(algorithm)); + + if (timer != nullptr && result) { + // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error + // state. + if (!timer->Stop(AsCUDAStream(stream))) { + return false; + } + output_profile_result->set_is_valid(true); + output_profile_result->set_algorithm(algorithm); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return result; +#endif +} + +bool CUDABlas::GetBlasGemmAlgorithms( + std::vector<blas::AlgorithmType> *out_algorithms) { +// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) +// were first introduced in CUDA 8. +#if CUDA_VERSION >= 8000 + for (cublasGemmAlgo_t algo : + {CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, + CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, + CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) { + out_algorithms->push_back(algo); + } +#endif + return true; +} + +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const Eigen::half &alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, + DeviceMemory<Eigen::half> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + template <typename T, typename FuncT> port::Status CUDABlas::DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 1226df6d65..6a33cd746b 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -79,10 +79,27 @@ class CUDABlas : public blas::BlasSupport { // stream: Stream to enqueue the BLAS operation onto. // pointer_mode_host: Indicate if the pointer to a scalar value is from host // (true) or device (false). + // err_on_failure: Whether to print an error if the cublas function fails. // args: Arguments of cuBLAS function. template <typename FuncT, typename... Args> + bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args... args); + + // Convenience functions that call DoBlasInternalImpl with different values + // for err_on_failure. + template <typename FuncT, typename... Args> bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, - Args... args); + Args... args) { + return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, args...); + } + template <typename FuncT, typename... Args> + bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream, + bool pointer_mode_host, Args... args) { + return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, + /*err_on_failure=*/false, args...); + } // A helper function to implement DoBlasGemmBatched interfaces for generic // types. @@ -95,6 +112,23 @@ class CUDABlas : public blas::BlasSupport { const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + // Helper function for implementing DoBlasGemmWithAlgorithm. + // + // We take alpha and beta by const reference because T might be Eigen::half, + // and we want to avoid pulling in a dependency on Eigen. When we pass the + // references to cublas, we essentially reinterpret_cast to __half, which is + // safe because Eigen::half inherits from __half. + template <typename T> + bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, const T &alpha, + const DeviceMemory<T> &a, int lda, + const DeviceMemory<T> &b, int ldb, + const T &beta, DeviceMemory<T> *c, int ldc, + blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + // mutex that guards the cuBLAS handle for this device. mutex mu_; diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 9b67531689..f28f965f2c 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -75,6 +75,10 @@ string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } string ToVlogString(blas::Side s) { return blas::SideString(s); } +string ToVlogString(blas::ComputationType ty) { + return blas::ComputationTypeString(ty); +} + string ToVlogString(const void *ptr) { if (ptr == nullptr) { return "null"; @@ -109,6 +113,8 @@ string ToVlogString(const DeviceMemoryBase *memory) { return ToVlogString(*memory); } +string ToVlogString(const Eigen::half &h) { return port::StrCat(h); } + string ToVlogString(int i) { return port::StrCat(i); } string ToVlogString(uint32 i) { return port::StrCat(i); } @@ -1520,21 +1526,33 @@ struct ThenBlasImpl { // arguments except the first one of Stream* type. Stream &operator()(Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - Args... args); + Args... args) { + return Run(stream, blas_func, /*record_error=*/true, args...); + } + + // Like operator(), but only calls stream->CheckError() if record_error is + // true. + Stream &Run(Stream *stream, + bool (blas::BlasSupport::*blas_func)(Stream *, Args...), + bool record_error, Args... args); }; template <typename... Args> -Stream &ThenBlasImpl<Args...>::operator()( +Stream &ThenBlasImpl<Args...>::Run( Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...), - Args... args) { + bool record_error, Args... args) { if (stream->ok()) { + bool ok; if (blas::BlasSupport *blas = stream->parent_->AsBlas()) { - stream->CheckError((blas->*blas_func)(stream, args...)); + ok = (blas->*blas_func)(stream, args...); } else { - stream->CheckError(false); LOG(WARNING) << "attempting to perform BLAS operation using StreamExecutor " "without BLAS support"; + ok = false; + } + if (record_error) { + stream->CheckError(ok); } } return *stream; @@ -3215,6 +3233,141 @@ Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, alpha, a, lda, b, ldb, beta, c, ldc); } +namespace { +// Like ThenBlasImpl, except this expects the last argument of blas_func to be a +// blas::ProfileResult*. This functor doesn't put the stream into an error +// state if the op fails and the profile result is non-null. Instead, the +// error-ness is returned in the profile result itself. +template <typename... Args> +struct ThenBlasWithProfileImpl { + Stream &operator()(Stream *stream, + bool (blas::BlasSupport::*blas_func)( + Stream *, Args..., blas::ProfileResult *), + Args... args, blas::ProfileResult *profile_result) { + ThenBlasImpl<Args..., blas::ProfileResult *> Runner; + bool record_error = profile_result == nullptr; + return Runner.Run(stream, blas_func, record_error, args..., profile_result); + } +}; +} // anonymous namespace + +Stream &Stream::ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, + const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), + PARAM(algorithm)); + + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, const Eigen::half &, + const DeviceMemory<Eigen::half> &, int, + const DeviceMemory<Eigen::half> &, int, + const Eigen::half &, DeviceMemory<Eigen::half> *, int, + blas::ComputationType, blas::AlgorithmType> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + algorithm, output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), + PARAM(algorithm)); + + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float, + DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + algorithm, output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), + PARAM(algorithm)); + + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, double, const DeviceMemory<double> &, int, + const DeviceMemory<double> &, int, double, + DeviceMemory<double> *, int, blas::ComputationType, + blas::AlgorithmType> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + algorithm, output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), + PARAM(algorithm)); + + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, int, + const DeviceMemory<std::complex<float>> &, int, std::complex<float>, + DeviceMemory<std::complex<float>> *, int, blas::ComputationType, + blas::AlgorithmType> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + algorithm, output_profile_result); +} + +Stream &Stream::ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), + PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), + PARAM(algorithm)); + + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, int, + const DeviceMemory<std::complex<double>> &, int, std::complex<double>, + DeviceMemory<std::complex<double>> *, int, blas::ComputationType, + blas::AlgorithmType> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, + m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + algorithm, output_profile_result); +} + Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index a092e87004..f22fba1d74 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1179,6 +1179,47 @@ class Stream { std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc); + // See BlasSupport::DoBlasGemmWithAlgorithm. + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, + const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, + const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, + float beta, DeviceMemory<float> *c, int ldc, + blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + // See BlasSupport::DoBlasGemmBatched. Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index c498eecb3c..42fcd5867c 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -310,6 +310,15 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms( return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms); } +bool StreamExecutor::GetBlasGemmAlgorithms( + std::vector<blas::AlgorithmType> *out_algorithms) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return false; + } + return blas_support->GetBlasGemmAlgorithms(out_algorithms); +} + port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> StreamExecutor::createRnnDescriptor( int num_layers, int hidden_size, int input_size, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 29ba63af05..5c52afa794 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -353,6 +353,9 @@ class StreamExecutor { bool GetConvolveBackwardFilterAlgorithms( std::vector<dnn::AlgorithmType> *out_algorithms); + // Get the list of supported algorithms for BLAS gemm. + bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms); + // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( |