author    Nathan Luehr <nluehr@nvidia.com>  2017-12-26 11:11:35 -0800
committer drpngx <drpngx@users.noreply.github.com>  2017-12-26 11:11:35 -0800
commit    bffa3e10bf4886f03a68f7e93ba39c91d447f101 (patch)
tree      9294a3db2202581cda0a73d0fd9632815dfdf288
parent    2f165f383045e6efbd472b52c33f6622cf164ec4 (diff)
Add support for CUBLAS_TENSOR_OP_MATH in fp16 GEMM (#13451)
- Applies to matrix multiplications with fp16 input/output. Computations will fall back to pseudo-fp16 if tensor op math is disabled or not supported.
- Enabled by default. Tensor ops (both in cuBLAS GEMMs and cuDNN convolutions) can be disabled globally by setting the environment variable TF_DISABLE_TENSOR_OP_MATH=1. To disable tensor ops specifically for GEMMs or convolutions, use TF_DISABLE_CUBLAS_TENSOR_OP_MATH=1 or TF_DISABLE_CUDNN_TENSOR_OP_MATH=1, respectively.
- Added CUBLAS 9.0 algorithms to GetBlasGemmAlgorithms().
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc  156
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.h    10
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc     9
3 files changed, 155 insertions(+), 20 deletions(-)
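
Note on the switches above: each is read once and memoized on first use, so it must be in the environment before any cuBLAS/cuDNN work runs. A minimal C++ harness sketch (setenv is POSIX, from <cstdlib>; in a shell you would simply export the variables):

#include <cstdlib>  // setenv (POSIX)

int main() {
  // Per the commit message: tensor op math is on by default; "1" disables it.
  setenv("TF_DISABLE_TENSOR_OP_MATH", "1", /*overwrite=*/1);         // global switch
  setenv("TF_DISABLE_CUBLAS_TENSOR_OP_MATH", "1", /*overwrite=*/1);  // cuBLAS GEMMs only
  setenv("TF_DISABLE_CUDNN_TENSOR_OP_MATH", "1", /*overwrite=*/1);   // cuDNN convolutions only
  // ... initialize TensorFlow after this point so the switches take effect.
  return 0;
}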
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index cb2b06d47c..34a4fcff48 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include <assert.h>
#include <complex>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
@@ -268,6 +269,11 @@ PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
#endif
+#if CUDA_VERSION >= 9000
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
@@ -299,6 +305,18 @@ static string ToString(cublasStatus_t status) {
}
}
+// Decide whether to enable TENSOR_OP_MATH
+static bool TensorOpMathEnabled() {
+ static bool is_enabled = [] {
+ bool is_disabled;
+ TF_CHECK_OK(
+ tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
+ /*default_val=*/false, &is_disabled));
+ return !is_disabled;
+ }();
+ return is_enabled;
+}
+
// cuBLAS has interfaces that permit pointers to be passed from either the host
// memory space or the device memory space; however, you must instruct it as to
// which address space those pointers are in with cublasSetPointerMode.
@@ -360,6 +378,66 @@ class ScopedCublasPointerMode {
bool ok_; // Whether the change was successful.
};
+#if CUDA_VERSION >= 9000
+// cuBLAS has interfaces that permit computations to use the Tensor Cores
+// available in Volta hardware. This must be enabled via the
+// cublasGet/SetMathMode APIs.
+//
+// This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
+// you are about to perform in a given scope.
+//
+// The prior cuBLAS math mode is retained and restored when this object goes
+// out of scope.
+class ScopedCublasMathMode {
+ public:
+ // Note that, because the setting of the cublas math mode is fallible,
+ // construction of this scoped datatype must be paired with a call to
+ // Init().
+ //
+ // Parameters:
+ // handle: The cublas library handle to act upon in setting the math mode.
+ explicit ScopedCublasMathMode(CUDAExecutor *parent, cublasHandle_t handle)
+ : parent_(parent), handle_(handle), ok_(false) {}
+
+ // Attempts the switch to the requested scoped math mode, new_mode.
+ //
+ // Note that when false is returned, an appropriate error has already been
+ // logged.
+ bool Init(cublasMath_t new_mode) {
+ cublasStatus_t ret = wrap::cublasGetMathMode(parent_, handle_, &old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
+ return ok_ = false;
+ }
+
+ ret = wrap::cublasSetMathMode(parent_, handle_, new_mode);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
+ return ok_ = false;
+ }
+ return ok_ = true;
+ }
+
+ // Switches back to the prior math mode, if the switch operation was
+ // successful in the first place.
+ ~ScopedCublasMathMode() {
+ if (ok_) {
+ cublasStatus_t ret = wrap::cublasSetMathMode(parent_, handle_, old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set former cublas math mode: "
+ << ToString(ret);
+ }
+ }
+ }
+
+ private:
+ CUDAExecutor *parent_; // Executor for which this math mode is established.
+ cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
+ cublasMath_t old_mode_; // Prior cuBLAS math mode, to be restored.
+ bool ok_; // Whether the change was successful.
+};
+#endif // CUDA_VERSION >= 9000
+
bool CUDABlas::Init() {
cublasStatus_t ret = wrap::cublasCreate(parent_, &blas_);
if (ret != CUBLAS_STATUS_SUCCESS) {
@@ -532,7 +610,7 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
- Args... args) {
+ bool use_tensor_op_math, Args... args) {
mutex_lock lock{mu_};
CHECK(blas_ != nullptr);
@@ -545,7 +623,14 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
: CUBLAS_POINTER_MODE_DEVICE)) {
return false;
}
-
+#if CUDA_VERSION >= 9000
+ ScopedCublasMathMode math_mode{parent_, blas_};
+ if (use_tensor_op_math) {
+ if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
+ return false;
+ }
+ }
+#endif
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
@@ -1762,14 +1847,26 @@ bool CUDABlas::DoBlasGemm(
"precondition violation";
}
}
- // TODO(sesse): Consider supporting the Hgemm interface, which uses half
- // calculations internally (faster on newer devices, such as Pascal and TX1,
- // but less precise).
- return DoBlasInternal(
+
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+
+ // GPUs before sm_70 (Volta) don't support tensor cores.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#endif
+
+ return DoBlasInternalImpl(
wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
- CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
- CUDAMemory(a), SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF,
- ldb, &beta, CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
+ true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, CUDAMemory(a),
+ SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
+ CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
+
#else
LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
<< "(need at least CUDA 7.5)";
@@ -2031,6 +2128,26 @@ bool CUDABlas::DoBlasGemmWithProfilingImpl(
return result;
}
+static bool UsesTensorOps(blas::AlgorithmType algo) {
+#if CUDA_VERSION >= 9000
+ cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
+ return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+#else
+ return false;
+#endif
+}
+
+template <typename InType>
+static bool TensorOpsAvailable(int cc_major) {
+#if CUDA_VERSION >= 9000
+ if (cc_major >= 7 && TensorOpMathEnabled() &&
+ std::is_same<InType, Eigen::half>::value) {
+ return true;
+ }
+#endif
+ return false;
+}
+
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@@ -2049,6 +2166,10 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
return false;
}
+ if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+ return false;
+ }
+
struct TimerDeleter {
void operator()(CUDATimer *t) {
t->Destroy();
@@ -2098,10 +2219,19 @@ bool CUDABlas::GetBlasGemmAlgorithms(
// still return the out_algorithms. Caller needs to make sure that in this case,
// the returned vector is empty.
#if CUDA_VERSION >= 8000
- for (cublasGemmAlgo_t algo :
- {CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
+ for (cublasGemmAlgo_t algo : {
+ CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+#if CUDA_VERSION >= 9000
+ CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP
+#endif
+ }) {
out_algorithms->push_back(algo);
}
#endif
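
The ScopedCublasMathMode class added above follows the same save/set/restore RAII shape as the existing ScopedCublasPointerMode. A minimal standalone sketch of that pattern, with a global enum standing in for the per-handle cuBLAS state and infallible Get/Set placeholders for cublasGetMathMode/cublasSetMathMode (the real calls can fail, hence the Init()/ok_ dance in the diff):

#include <iostream>

enum Mode { kDefaultMath = 0, kTensorOpMath = 1 };
static Mode g_mode = kDefaultMath;       // Stand-in for per-handle cuBLAS state.
static Mode Get() { return g_mode; }     // Placeholder for cublasGetMathMode.
static void Set(Mode m) { g_mode = m; }  // Placeholder for cublasSetMathMode.

// Saves the current mode on construction, restores it on scope exit.
class ScopedMode {
 public:
  explicit ScopedMode(Mode new_mode) : old_mode_(Get()) { Set(new_mode); }
  ~ScopedMode() { Set(old_mode_); }
 private:
  Mode old_mode_;
};

int main() {
  {
    ScopedMode scope(kTensorOpMath);
    std::cout << "inside scope: " << Get() << "\n";  // prints 1
  }                                                  // mode restored here
  std::cout << "after scope: " << Get() << "\n";     // prints 0
}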
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 80cda97117..deb211c04b 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -84,7 +84,7 @@ class CUDABlas : public blas::BlasSupport {
template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
- Args... args);
+ bool use_tensor_op_math, Args... args);
// Convenience functions that call DoBlasInternalImpl with different values
// for err_on_failure.
@@ -92,13 +92,17 @@ class CUDABlas : public blas::BlasSupport {
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
Args... args) {
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/true, args...);
+ /*err_on_failure=*/true, /*use_tensor_ops=*/false,
+ args...);
}
template <typename FuncT, typename... Args>
bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
+ // Tensor ops are hard-coded off in this path, but can still be enabled with
+ // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/false, args...);
+ /*err_on_failure=*/false,
+ /*use_tensor_ops=*/false, args...);
}
// A helper function to implement DoBlasGemmBatched interfaces for generic
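
The header change threads the new use_tensor_op_math flag through the variadic DoBlasInternalImpl; both convenience wrappers pin it to false. A compilable sketch of that forwarding shape, with Impl and Convenience as illustrative stand-ins rather than the real API:

#include <iostream>

// Illustrative stand-in for DoBlasInternalImpl.
template <typename FuncT, typename... Args>
bool Impl(FuncT f, bool err_on_failure, bool use_tensor_ops, Args... args) {
  (void)use_tensor_ops;  // Real code would set the cuBLAS math mode here.
  bool ok = f(args...);
  if (!ok && err_on_failure) std::cerr << "routine failed\n";
  return ok;
}

// Illustrative stand-in for DoBlasInternal.
template <typename FuncT, typename... Args>
bool Convenience(FuncT f, Args... args) {
  // Policy flags are bound up front; the BLAS arguments pass through unchanged.
  return Impl(f, /*err_on_failure=*/true, /*use_tensor_ops=*/false, args...);
}

int main() {
  auto fake_gemm = [](int m, int n, int k) {
    std::cout << "gemm " << m << "x" << n << "x" << k << "\n";
    return true;
  };
  return Convenience(fake_gemm, 128, 128, 64) ? 0 : 1;
}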
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 5519381d51..384445e6c1 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -559,10 +559,11 @@ class ScopedFilterDescriptor {
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
static bool TensorOpMathEnabled() {
static bool is_enabled = [] {
- bool ret;
- TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_TENSOR_OP_MATH",
- /*default_val=*/false, &ret));
- return !ret;
+ bool is_disabled;
+ TF_CHECK_OK(
+ tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
+ /*default_val=*/false, &is_disabled));
+ return !is_disabled;
}();
return is_enabled;
}
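
Taken together, the gating in this patch is: sm_70+ hardware, the opt-out env var unset, and (on the algorithm-selection path) fp16 inputs. A hedged sketch of the combined predicate; HalfT stands in for Eigen::half, and the plain getenv check approximates tensorflow::ReadBoolFromEnvVar, which also accepts values like "true":

#include <cstdlib>
#include <cstring>
#include <type_traits>

struct HalfT {};  // Placeholder for Eigen::half.

// Approximates TensorOpMathEnabled(); the real code memoizes the result.
static bool Enabled(const char* var) {
  const char* v = std::getenv(var);
  const bool disabled = (v != nullptr && std::strcmp(v, "1") == 0);
  return !disabled;
}

// Mirrors the shape of TensorOpsAvailable<InType>(cc_major) from the diff:
// Volta-or-newer hardware, opt-out unset, fp16 inputs.
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
  return cc_major >= 7 && Enabled("TF_DISABLE_CUBLAS_TENSOR_OP_MATH") &&
         std::is_same<InType, HalfT>::value;
}

int main() {
  return TensorOpsAvailable<HalfT>(/*cc_major=*/7) ? 0 : 1;  // 0 when available
}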