diff options
author | Yifei Feng <yifeif@google.com> | 2018-05-24 19:12:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-24 19:15:01 -0700 |
commit | b59833c3fd91511b33255369016868e4ae6cda2e (patch) | |
tree | ecbd70cfd3abb5d934f6eb4b7280a35e8589f5cf /tensorflow/stream_executor/cuda | |
parent | 2b99d9cbc7166efedaff9eee11744348da30fc8a (diff) |
Merge changes from github.
Revert #18413. Too many internal test failures due to the name scope change caused by this change.
Revert #18192. Cannot use re2::StringPiece internally. Need alternative for set call. Will pull and clean this up in a separate change.
PiperOrigin-RevId: 197991247
Diffstat (limited to 'tensorflow/stream_executor/cuda')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 106 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.h | 6 |
2 files changed, 96 insertions, 16 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 3e9a23c658..08fe153b59 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -286,6 +286,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode) STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #endif +#if CUDA_VERSION >= 9010 +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +#endif + } // namespace wrap static string ToString(cublasStatus_t status) { @@ -2330,13 +2334,23 @@ bool CUDABlas::DoBlasGemmWithAlgorithm( computation_type, algorithm, output_profile_result); } -template <typename T, typename FuncT> +template <typename T> +struct HalfAsFloat { + typedef T type; +}; + +template <> +struct HalfAsFloat<Eigen::half> { + typedef float type; +}; + +template <typename T, typename Scalar, typename FuncT> port::Status CUDABlas::DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha, const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda, const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb, - T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers, + Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; for (int i = 0; i < batch_count; ++i) { @@ -2345,7 +2359,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque())); } - typedef typename CUDAComplexT<T>::type CUDA_T; + typedef typename HalfAsFloat<typename CUDAComplexT<T>::type>::type CUDA_T; const size_t size = batch_count * sizeof(CUDA_T *); @@ -2397,18 +2411,84 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( "CUDABlas::DoBlasGemmBatched"); } - bool ok = DoBlasInternal( - cublas_func, stream, true /* = pointer_mode_host */, - CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, - CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda, - const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta), - const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count); + cudaDataType_t data_type = CUDADataType<T>::type; - if (ok) { +#if CUDA_VERSION >= 9010 + int cc_major, cc_minor; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor) && + cc_major >= 5) { + bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F; + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + cudaDataType_t compute_type = + (data_type == CUDA_R_16F ? CUDA_R_32F : data_type); + const void **a_void_ptrs = reinterpret_cast<const void **>( + const_cast<const CUDA_T **>(CUDAMemory(a))); + const void **b_void_ptrs = reinterpret_cast<const void **>( + const_cast<const CUDA_T **>(CUDAMemory(b))); + void **c_void_ptrs = + reinterpret_cast<void **>(const_cast<CUDA_T **>(CUDAMemory(c))); + bool ok; + ok = DoBlasInternalImpl( + wrap::cublasGemmBatchedEx, stream, true /* = pointer_mode_host */, + true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda, + b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc, + batch_count, compute_type, algo); + if (ok) { + return port::Status::OK(); + } + return port::Status(port::error::INTERNAL, + "failed BLAS call, see log for details"); + } +#endif + // either CUDA_VERSION < 9.1 or SM < 5.0 + if (data_type != CUDA_R_16F) { + bool ok = DoBlasInternal( + cublas_func, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda, + const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta), + const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count); + if (ok) { + return port::Status::OK(); + } + return port::Status(port::error::INTERNAL, + "failed BLAS call, see log for details"); + } else { + // Fall back to a loop for fp16 + for (int b = 0; b < batch_count; ++b) { + const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b]; + const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b]; + DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b]; + bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix, + lda, b_matrix, ldb, beta, c_matrix, ldc); + if (!ok) { + return port::Status(port::error::INTERNAL, + "failed BLAS call, see log for details"); + } + } return port::Status::OK(); } - return port::Status(port::error::INTERNAL, - "failed BLAS call, see log for details"); +} + +bool CUDABlas::DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, + const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda, + const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb, + float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array, + int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + // Note: The func passed here (cublasSgemmBatched) is not actually called, + // due to special handling of fp16 inside DoBlasGemmBatchedInternal. + port::Status status = DoBlasGemmBatchedInternal( + wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, + lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + if (!status.ok()) { + LOG(ERROR) << status; + } + return status.ok(); } bool CUDABlas::DoBlasGemmBatched( diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 12dc5e47fd..42b3fde5b0 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -107,12 +107,12 @@ class CUDABlas : public blas::BlasSupport { // A helper function to implement DoBlasGemmBatched interfaces for generic // types. - template <typename T, typename FuncT> + template <typename T, typename Scalar, typename FuncT> port::Status DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha, const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda, - const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta, + const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, int batch_count, ScratchAllocator *scratch_allocator); |