aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_blas.cc
diff options
context:
space:
mode:
author Justin Lebar <jlebar@google.com> 2017-03-02 17:49:45 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-03-02 18:08:01 -0800
commit 01194694948eb883e99af597d9dbbf3fc9f5c9e2 (patch)
tree ab3517cf656259681283a90c6682c5b320ac36e3 /tensorflow/stream_executor/cuda/cuda_blas.cc
parent e065b3093f4fec5a5f79ad9de81f6baab361962e (diff)
[XLA] [StreamExecutor] Tune GEMMs when possible.
cublas 8 adds the cublasGemmEx function, which lets you specify an explicit "algorithm" for the computation. This functions as an opaque tuning hint to cublas. This patch adds support for cublasGemmEx to StreamExecutor, and wires up XLA's GemmThunk to use the new function. This patch does not add GEMM autotuning support in TensorFlow proper, only XLA. Change: 149068961
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_blas.cc 231
1 files changed, 225 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 3df366bc4d..2c650afc70 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
@@ -262,6 +263,10 @@ CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
#endif
+#if CUDA_VERSION >= 8000
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
@@ -282,6 +287,12 @@ static string ToString(cublasStatus_t status) {
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
+#if CUDA_VERSION >= 8000
+ case CUBLAS_STATUS_NOT_SUPPORTED:
+ return "CUBLAS_STATUS_NOT_SUPPORTED";
+ case CUBLAS_STATUS_LICENSE_ERROR:
+ return "CUBLAS_STATUS_LICENSE_ERROR";
+#endif
default:
return port::StrCat("<invalid cublas status: ", status, ">");
}
@@ -431,11 +442,89 @@ cublasSideMode_t CUDABlasSide(blas::Side side) {
}
}
+// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
+// cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a
+// blas::ComputationType to a cudaDataType_t.
+//
+// These are used to build the argument type and computation type args to
+// cublasGemmEx. cublasGemmEx and cudaDataType_t are available only on
+// CUDA >= 8.0.
+#if CUDA_VERSION >= 8000
+// Primary template. Only declared, never defined: an element type without a
+// specialization below has no CUDADataType<T>::type to read.
+template <typename T>
+struct CUDADataType;
+
+// Holder for the constant each specialization publishes. Specializations
+// inherit from it so that CUDADataType<T>::type resolves to `kValue`.
+template <cudaDataType_t kValue>
+struct CUDADataTypeConstant {
+  static constexpr cudaDataType_t type = kValue;
+};
+
+// 16-bit floating point.
+template <>
+struct CUDADataType<Eigen::half> : CUDADataTypeConstant<SE_CUDA_DATA_HALF> {};
+
+template <>
+struct CUDADataType<std::complex<Eigen::half>>
+    : CUDADataTypeConstant<CUDA_C_16F> {};
+
+// 32-bit floating point.
+template <>
+struct CUDADataType<float> : CUDADataTypeConstant<CUDA_R_32F> {};
+
+template <>
+struct CUDADataType<std::complex<float>> : CUDADataTypeConstant<CUDA_C_32F> {};
+
+// 64-bit floating point.
+template <>
+struct CUDADataType<double> : CUDADataTypeConstant<CUDA_R_64F> {};
+
+template <>
+struct CUDADataType<std::complex<double>> : CUDADataTypeConstant<CUDA_C_64F> {};
+
+// 8-bit signed integer.
+template <>
+struct CUDADataType<int8> : CUDADataTypeConstant<CUDA_R_8I> {};
+
+template <>
+struct CUDADataType<std::complex<int8>> : CUDADataTypeConstant<CUDA_C_8I> {};
+
+// 8-bit unsigned integer.
+template <>
+struct CUDADataType<uint8> : CUDADataTypeConstant<CUDA_R_8U> {};
+
+template <>
+struct CUDADataType<std::complex<uint8>> : CUDADataTypeConstant<CUDA_C_8U> {};
+
+// Translates a blas::ComputationType into the cudaDataType_t value that is
+// passed as the computation-type argument of cublasGemmEx.
+//
+// Each enumerator listed in the switch returns directly. The switch has no
+// `default:` label on purpose, so the compiler can still warn when a new
+// enumerator is added without a case. A value that matches no case (for
+// example one produced by a cast) is reported as a fatal error rather than
+// letting control run off the end of a value-returning function, which would
+// be undefined behavior.
+cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
+  switch (ty) {
+    case blas::ComputationType::kF16:
+      return CUDA_R_16F;
+    case blas::ComputationType::kF32:
+      return CUDA_R_32F;
+    case blas::ComputationType::kF64:
+      return CUDA_R_64F;
+    case blas::ComputationType::kComplexF32:
+      return CUDA_C_32F;
+    case blas::ComputationType::kComplexF64:
+      return CUDA_C_64F;
+  }
+  LOG(FATAL) << "Invalid value of blas::ComputationType: "
+             << static_cast<int>(ty);
+  // LOG(FATAL) is expected to terminate the process; this return only keeps
+  // every path of the function returning a value.
+  return CUDA_R_32F;
+}
+#endif
+
} // namespace
template <typename FuncT, typename... Args>
-bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream,
- bool pointer_mode_host, Args... args) {
+bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
+ bool pointer_mode_host, bool err_on_failure,
+ Args... args) {
mutex_lock lock{mu_};
CHECK(blas_ != nullptr);
@@ -450,13 +539,11 @@ bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream,
}
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
- if (ret != CUBLAS_STATUS_SUCCESS) {
+ if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
<< ToString(ret);
- return false;
}
-
- return true;
+ return ret == CUBLAS_STATUS_SUCCESS;
}
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
@@ -1762,6 +1849,138 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
CUDAComplex(CUDAMemoryMutable(c)), ldc);
}
+// Issues a GEMM through cublasGemmEx with an explicitly selected algorithm,
+// optionally timing it.
+//
+// Parameters:
+//   stream                 stream the GEMM (and the timer, if any) is issued on.
+//   transa, transb         transpose modes, converted with CUDABlasTranspose.
+//   m, n, k, lda, ldb, ldc problem dimensions and leading dimensions, passed
+//                          through to cublasGemmEx.
+//   alpha, beta            host scalars (the call uses host pointer mode).
+//   a, b, c                device buffers; all three are described to cuBLAS
+//                          with the same data type, CUDADataType<T>::type.
+//   computation_type       converted with CUDAComputationType.
+//   algorithm              cast to cublasGemmAlgo_t and passed through.
+//   output_profile_result  may be null. When non-null it is marked invalid on
+//                          entry and is filled in (valid, algorithm, elapsed
+//                          milliseconds) only if the GEMM and the timer both
+//                          succeed.
+//
+// Returns true iff cublasGemmEx reported success and, when profiling was
+// requested, the timer could be started and stopped. Returns false without
+// calling cuBLAS when built against CUDA < 8 or when the device reports a
+// compute capability major version below 5.
+//
+// NOTE(review): alpha and beta are handed to cublasGemmEx as pointers to T
+// regardless of computation_type. If cuBLAS reads the scalars in the
+// computation type, T = Eigen::half combined with kF32 would be mismatched --
+// confirm against the cublasGemmEx documentation.
+template <typename T>
+bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
+    const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  // Start from "no valid measurement" so that every early return below, and a
+  // failed GEMM, leave the caller with a result that cannot be mistaken for a
+  // timing of this call (e.g. when the same ProfileResult object is reused).
+  if (output_profile_result != nullptr) {
+    output_profile_result->set_is_valid(false);
+  }
+
+// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
+#if CUDA_VERSION < 8000
+  return false;
+#else
+  int cc_major, cc_minor;
+  // The capability is only inspected when the query succeeds; if it fails the
+  // call proceeds and cuBLAS decides.
+  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+          &cc_major, &cc_minor) &&
+      cc_major < 5) {
+    return false;
+  }
+
+  // Destroys the timer before freeing it, on every exit path.
+  struct TimerDeleter {
+    void operator()(CUDATimer *t) {
+      t->Destroy();
+      delete t;
+    }
+  };
+  std::unique_ptr<CUDATimer, TimerDeleter> timer;
+  if (output_profile_result != nullptr) {
+    timer.reset(new CUDATimer(parent_));
+    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+      return false;
+    }
+  }
+
+  cudaDataType_t data_type = CUDADataType<T>::type;
+  // The "FailureOK" variant is used so that an algorithm cuBLAS rejects for
+  // this problem is reported through the return value only.
+  bool result = DoBlasInternalFailureOK(
+      wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
+      CUDAMemoryMutable(c), data_type, ldc,
+      CUDAComputationType(computation_type),
+      static_cast<cublasGemmAlgo_t>(algorithm));
+
+  if (timer != nullptr && result) {
+    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
+    // state.
+    if (!timer->Stop(AsCUDAStream(stream))) {
+      return false;
+    }
+    output_profile_result->set_is_valid(true);
+    output_profile_result->set_algorithm(algorithm);
+    output_profile_result->set_elapsed_time_in_ms(
+        timer->GetElapsedMilliseconds());
+  }
+  return result;
+#endif
+}
+
+// Appends the cuBLAS GEMM algorithm identifiers known to this build to
+// `*out_algorithms`, default algorithm first. Always returns true.
+//
+// cublasGemmAlgo_t, and cublasGemmEx which consumes it, first appeared in
+// CUDA 8; when built against an older toolkit nothing is appended.
+bool CUDABlas::GetBlasGemmAlgorithms(
+    std::vector<blas::AlgorithmType> *out_algorithms) {
+#if CUDA_VERSION >= 8000
+  static const cublasGemmAlgo_t kGemmAlgorithms[] = {
+      CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
+      CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
+      CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+  };
+  for (const cublasGemmAlgo_t gemm_algo : kGemmAlgorithms) {
+    out_algorithms->push_back(gemm_algo);
+  }
+#endif
+  return true;
+}
+
+// Eigen::half overload: forwards every argument unchanged to the shared
+// DoBlasGemmWithAlgorithmImpl template and returns its result.
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const Eigen::half &alpha,
+    const DeviceMemory<Eigen::half> &a, int lda,
+    const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
+    DeviceMemory<Eigen::half> *c, int ldc,
+    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl<Eigen::half>(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
+// float overload: forwards every argument unchanged to the shared
+// DoBlasGemmWithAlgorithmImpl template and returns its result.
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl<float>(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
+// double overload: forwards every argument unchanged to the shared
+// DoBlasGemmWithAlgorithmImpl template and returns its result.
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    const DeviceMemory<double> &b, int ldb, double beta,
+    DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl<double>(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
+// std::complex<float> overload: forwards every argument unchanged to the
+// shared DoBlasGemmWithAlgorithmImpl template and returns its result.
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda,
+    const DeviceMemory<std::complex<float>> &b, int ldb,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl<std::complex<float>>(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
+// std::complex<double> overload: forwards every argument unchanged to the
+// shared DoBlasGemmWithAlgorithmImpl template and returns its result.
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda,
+    const DeviceMemory<std::complex<double>> &b, int ldb,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl<std::complex<double>>(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
template <typename T, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,