Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 66
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 217
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 78
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_driver.cc | 16
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 15
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 304
-rw-r--r-- | tensorflow/stream_executor/stream.h | 44
-rw-r--r-- | tensorflow/stream_executor/stream_executor_internal.cc | 12
-rw-r--r-- | tensorflow/stream_executor/stream_executor_internal.h | 2
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 5
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 5
-rw-r--r-- | tensorflow/stream_executor/stream_test.cc | 90
12 files changed, 722 insertions, 132 deletions
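The bulk of this change adds a strided-batched GEMM path: five DoBlasGemmStridedBatched overloads in blas.h (half, float, double, and the two complex types), cuBLAS-backed implementations in cuda_blas.cc, and matching Stream::ThenBlasGemmStridedBatched entry points in stream.h/stream.cc. Unlike the existing pointer-array batched GEMM, each operand is a single device allocation packing batch_count matrices at a fixed element stride. The sketch below is illustrative only (the function and buffer names are hypothetical, not from this patch); it assumes cuBLAS column-major layout, with leading dimensions and strides counted in elements:

    #include <cstdint>
    #include "tensorflow/stream_executor/stream.h"

    namespace se = stream_executor;

    // Computes c[i] = a[i] * b[i] for batch_count n-by-n matrices, where
    // matrix i of each operand starts at element offset i * n * n.
    void BatchedMatmulSketch(se::Stream* stream,
                             const se::DeviceMemory<float>& a,
                             const se::DeviceMemory<float>& b,
                             se::DeviceMemory<float>* c, int n,
                             int batch_count) {
      const int64_t stride = static_cast<int64_t>(n) * n;  // elements, not bytes
      stream->ThenBlasGemmStridedBatched(
          se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
          /*m=*/n, /*n=*/n, /*k=*/n, /*alpha=*/1.0f,
          a, /*lda=*/n, /*stride_a=*/stride,
          b, /*ldb=*/n, /*stride_b=*/stride,
          /*beta=*/0.0f, c, /*ldc=*/n, /*stride_c=*/stride, batch_count);
    }

Per the cuda_blas.cc hunks below, the fp16 overload dispatches to cublasGemmStridedBatchedEx (with tensor ops on sm_70+), and falls back to a per-batch cublasSgemmEx loop on CUDA < 9.1 or pre-sm_50 devices.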
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ea87744b22..7f851e3646 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1121,6 +1121,40 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Batched gemm with strides instead of pointer arrays. + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + // Computes a matrix-matrix product where one input matrix is Hermitian: // // c <- alpha * a * b + beta * c, @@ -1990,6 +2024,38 @@ class BlasSupport { int ldb, std::complex<double> beta, \ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ + const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ + int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, 
double alpha, \ + const DeviceMemory<double> &a, int lda, int64 stride_a, \ + const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ + DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex<float> alpha, \ const DeviceMemory<std::complex<float>> &a, int lda, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 874bf0e8cb..ab7091b3f5 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx) #if CUDA_VERSION >= 8000 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched) #endif #if CUDA_VERSION >= 9000 @@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #if CUDA_VERSION >= 9010 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx) #endif } // namespace wrap @@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, } #endif cublasStatus_t ret = cublas_func(parent_, blas_, args...); - if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) { + if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": " << ToString(ret); } @@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm( stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); - // GPUs < sm_70 don't support Volta hardware. + // GPUs < sm_70 don't support tensor ops. if (cc_major >= 7 && TensorOpMathEnabled()) { use_tensor_ops = true; } @@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) { template <typename InType> static bool TensorOpsAvailable(int cc_major) { #if CUDA_VERSION >= 9000 + // cublas *does* allow tensor ops on inputs that are not fp16, so this is not + // strictly correct. We can't simply enable it, though, as that would change + // clients' behavior significantly: Using tensor ops on fp32 inputs causes them + // to be rounded to fp16.
if (cc_major >= 7 && TensorOpMathEnabled() && std::is_same<InType, Eigen::half>::value) { return true; @@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (stream->parent()->GetDeviceDescription().cuda_compute_capability( &cc_major, &cc_minor) && cc_major < 5) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major + << cc_minor << " devices don't support explicit gemm algorithms."; return false; } if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) { + if (std::is_same<InT, Eigen::half>::value) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but tensor ops are not available in sm" + << cc_major << "X devices."; + } else { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but the input data type is not fp16."; + } return false; } // Either both 'alpha' and 'beta' need to be pointers to device memory, or // they need to be both host scalars. if (alpha.is_pointer() != beta.is_pointer()) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` " "and `beta` is a pointer, but the other is not."; return false; } @@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (output_profile_result != nullptr) { timer.reset(new CUDATimer(parent_)); if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because " + "output_profile_result was given, but we were unable to " + "create a CUDATimer."; return false; } } @@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) && std::max({m, n, k}) >= 2097153 && cc_major < 7) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around a " + "cuBLAS bug in CUDA < 9.2 with m, n, or k >= 2097153. See b/79126339."; return false; } #endif @@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error // state. if (!timer->Stop(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop " + "CUDATimer."; return false; } output_profile_result->set_is_valid(true); @@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( bool CUDABlas::GetBlasGemmAlgorithms( std::vector<blas::AlgorithmType> *out_algorithms) { -// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) -// were first introduced in CUDA 8. -// Note that when CUDA version and compute capability is not sufficient, we -// still return the out_algorithms. Caller needs to make sure that in this case, -// the returned vector is empty. - for (cublasGemmAlgo_t algo : { - CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, - CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, - CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, + // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) + // were first introduced in CUDA 8. + // + // Note that when the CUDA version and compute capability are not sufficient, + // we still return the out_algorithms. The caller needs to make sure that in + // this case, the returned vector is empty.
+ *out_algorithms = { + CUBLAS_GEMM_DFALT, + CUBLAS_GEMM_ALGO0, + CUBLAS_GEMM_ALGO1, + CUBLAS_GEMM_ALGO2, + CUBLAS_GEMM_ALGO3, + CUBLAS_GEMM_ALGO4, + CUBLAS_GEMM_ALGO5, + CUBLAS_GEMM_ALGO6, + CUBLAS_GEMM_ALGO7, #if CUDA_VERSION >= 9000 - CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10, - CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13, - CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16, - CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP, - CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP, - CUBLAS_GEMM_ALGO2_TENSOR_OP + CUBLAS_GEMM_ALGO8, + CUBLAS_GEMM_ALGO9, + CUBLAS_GEMM_ALGO10, + CUBLAS_GEMM_ALGO11, + CUBLAS_GEMM_ALGO12, + CUBLAS_GEMM_ALGO13, + CUBLAS_GEMM_ALGO14, + CUBLAS_GEMM_ALGO15, + CUBLAS_GEMM_ALGO16, + CUBLAS_GEMM_ALGO17, + CUBLAS_GEMM_DFALT_TENSOR_OP, + CUBLAS_GEMM_ALGO0_TENSOR_OP, + CUBLAS_GEMM_ALGO1_TENSOR_OP, + CUBLAS_GEMM_ALGO2_TENSOR_OP, + CUBLAS_GEMM_ALGO3_TENSOR_OP, + CUBLAS_GEMM_ALGO4_TENSOR_OP, #endif - }) { - out_algorithms->push_back(algo); - } +#if CUDA_VERSION >= 9200 + CUBLAS_GEMM_ALGO18, + CUBLAS_GEMM_ALGO19, + CUBLAS_GEMM_ALGO20, + CUBLAS_GEMM_ALGO21, + CUBLAS_GEMM_ALGO22, + CUBLAS_GEMM_ALGO23, + CUBLAS_GEMM_ALGO5_TENSOR_OP, + CUBLAS_GEMM_ALGO6_TENSOR_OP, + CUBLAS_GEMM_ALGO7_TENSOR_OP, + CUBLAS_GEMM_ALGO8_TENSOR_OP, + CUBLAS_GEMM_ALGO9_TENSOR_OP, + CUBLAS_GEMM_ALGO10_TENSOR_OP, + CUBLAS_GEMM_ALGO11_TENSOR_OP, + CUBLAS_GEMM_ALGO12_TENSOR_OP, + CUBLAS_GEMM_ALGO13_TENSOR_OP, + CUBLAS_GEMM_ALGO14_TENSOR_OP, + CUBLAS_GEMM_ALGO15_TENSOR_OP, +#endif + }; return true; } @@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched( return status.ok(); } +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) { + bool use_tensor_ops = false; +#if CUDA_VERSION >= 9000 + int cc_major, cc_minor; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor)) { + // GPUs < sm_70 don't support tensor ops. + if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#if CUDA_VERSION >= 9010 + if (cc_major >= 5) { + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + bool ok = DoBlasInternalImpl( + wrap::cublasGemmStridedBatchedEx, stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a, + CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); + if (ok) { + return true; + } + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } +#endif + } +#endif + // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. 
+ for (int batch = 0; batch < batch_count; ++batch) { + const auto *a_matrix = + reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a); + const auto *b_matrix = + reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b); + auto *c_matrix = + reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c); + bool ok = DoBlasInternalImpl( + wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */, + true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, + lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix, + SE_CUDA_DATA_HALF, ldc); + if (!ok) { + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } + } + return true; +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + 
CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex<float> alpha, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 1c3940e92c..55408ab9ab 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1986,15 +1986,14 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2002,8 +2001,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( cudnn.handle(), /*xDesc=*/input_nd.handle(), /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2028,15 +2033,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionBackwardDataWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. 
size_t size_in_bytes; @@ -2046,8 +2050,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*dxDesc=*/input_nd.handle(), - /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), + /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2072,15 +2082,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2090,8 +2099,14 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), - /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), + /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2138,7 +2153,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm( } auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2155,11 +2170,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( @@ -2187,7 +2202,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2204,11 +2219,11 @@ port::StatusOr<dnn::AlgorithmDesc> 
GetCudnnConvolutionBackwardDataAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( @@ -2236,7 +2251,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2253,11 +2268,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN(*scratch, - AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algorithm_config.algorithm(), input_nd, - filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } // A helper class to set env-vars and choose options for cudnn-related @@ -3082,8 +3097,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl( } // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not - // zero-initialized. - // TODO(timshen): Add an nvbugs/ link. + // zero-initialized, nvbugs/2254619. if (CUDNN_VERSION >= 7000 && algorithm_config.algorithm().algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index dbece3adf9..f982f34b98 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/human_readable.h" #include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/notification.h" +#include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/strcat.h" @@ -66,14 +67,17 @@ class CreatedContexts { return Live()->find(context) != Live()->end(); } - // Adds context to the live set. + // Adds context to the live set, or returns it if it's already present. static CudaContext* Add(CUcontext context) { CHECK(context != nullptr); mutex_lock lock(mu_); - auto cuda_context = new CudaContext(context, next_id_++); - Live()->insert( - std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context))); - return cuda_context; + auto insert_result = Live()->insert(std::make_pair(context, nullptr)); + auto it = insert_result.first; + if (insert_result.second) { + // context was not present in the map. Add it. 
+ it->second = MakeUnique<CudaContext>(context, next_id_++); + } + return it->second.get(); } // Removes context from the live set. @@ -427,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, *context = CreatedContexts::Add(new_context); CHECK(*context != nullptr) << "success in this call must entail non-null result"; - VLOG(2) << "created context " << context << " for this thread"; + VLOG(2) << "created or reused context " << context << " for this thread"; return port::Status::OK(); } diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index a7449c2df4..9abfa1db6a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -713,15 +713,23 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} + AlgorithmDesc() + : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {} AlgorithmDesc(Index a, bool use_tensor_ops) - : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} + : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {} + AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size) + : algo_(a), + tensor_ops_enabled_(use_tensor_ops), + scratch_size_(scratch_size) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } bool tensor_ops_enabled() const { return tensor_ops_enabled_; } Index algo_id() const { return algo_; } + size_t scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmDesc& other) const { return this->algo_ == other.algo_ && - this->tensor_ops_enabled_ == other.tensor_ops_enabled_; + this->tensor_ops_enabled_ == other.tensor_ops_enabled_ && + this->scratch_size_ == other.scratch_size_; } uint64 hash() const; @@ -729,6 +737,7 @@ class AlgorithmDesc { enum { kDefaultAlgorithm = -1 }; Index algo_; bool tensor_ops_enabled_; + size_t scratch_size_; }; // Describes the result from a perf experiment. diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 6248aa2d01..9efd34de24 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto ¶m : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -267,13 +268,13 @@ Stream::Stream(StreamExecutor *parent, Stream::~Stream() { VLOG_CALL(); - temporary_memory_manager_.ForceDeallocateAll(); // Ensure the stream is completed. 
auto status = BlockHostUntilDone(); if (!status.ok()) { LOG(WARNING) << "Error blocking host until done in stream destructor: " << status; } + temporary_memory_manager_.ForceDeallocateAll(); if (allocated_) { parent_->DeallocateStream(this); @@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. + for (int64 index = 0; index < sub_streams_.size();) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. + Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - // Streams have a monotonic state machine; if a stream - // encounters an error, it will remain in an error state - // forever. Only allow re-use of ok streams. - // - // TODO(toddw): Improve this mechanism, if necessary, to drop - // failed streams completely. - const bool ready_to_reuse = sub_stream->ok(); - stream.second = ready_to_reuse; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. 
+ VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4685,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( scratch_allocator); } +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<Eigen::half> &, int, int64, + const DeviceMemory<Eigen::half> &, int, int64, float, + DeviceMemory<Eigen::half> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<float> &, int, int64, + 
const DeviceMemory<float> &, int, int64, float, + DeviceMemory<float> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, + const DeviceMemory<double> &, int, int64, + const DeviceMemory<double> &, int, int64, double, + DeviceMemory<double> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, + int, int64, const DeviceMemory<std::complex<float>> &, int, + int64, std::complex<float>, DeviceMemory<std::complex<float>> *, + int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, + int, int64, const DeviceMemory<std::complex<double>> &, int, + int64, std::complex<double>, + DeviceMemory<std::complex<double>> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + Stream 
&Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); @@ -4693,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast<const void *>(seed) << "; bytes: " << seed_bytes; } @@ -4711,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4727,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4743,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4758,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4774,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4790,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4805,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4818,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); 
} else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4831,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4843,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4856,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5125,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; } +Stream &Stream::ThenDoHostCallbackWithStatus( + std::function<port::Status()> callback) { + VLOG_CALL(PARAM(callback)); + + if (ok()) { + CheckError(parent_->HostCallback(this, std::move(callback))); + } else { + LOG(WARNING) << "stream " << DebugStreamPointers() + << " was in error state before adding host callback"; + } + return *this; +} + Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<std::complex<float>> &input, DeviceMemory<std::complex<float>> *output) { @@ -5141,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5158,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5174,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5190,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5207,8 +5387,9 @@ Stream 
&Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5224,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5252,7 +5434,7 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } @@ -5263,4 +5445,10 @@ port::Status Stream::BlockHostUntilDone() { return error; } +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. + return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); +} + } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 706442a666..e1629b5b30 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -122,10 +122,14 @@ class Stream { // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. + // + // TODO(b/112196569): The semantics of failed sub-streams are error-prone. Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused // later. Sub-streams that are !ok() will not be reused. + // + // TODO(b/112196569): The semantics of failed sub-streams are error-prone. void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_); // Allocate temporary memories.
The stream will deallocate them when blocked @@ -1557,6 +1561,38 @@ class Stream { std::complex<double> beta, const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count); // See BlasSupport::DoBlasHemm. Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, @@ -2009,6 +2045,11 @@ class Stream { // negative effects on performance. Stream &ThenDoHostCallback(std::function<void()> callback); + // Entrains onto the stream a callback to the host (from the device). + // Behaves as ThenDoHostCallback above, but returns a Status instead of void. + // This overload should be preferred if the callback could fail. + Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback); + // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); @@ -2019,6 +2060,9 @@ class Stream { // with this stream. internal::TemporaryMemoryManager *temporary_memory_manager(); + // Returns a debugging string "[stream=0x...,impl=0x...]". + string DebugStreamPointers() const; + private: friend class host::HostBlas; // for parent_. friend class host::HostFft; // for parent_. 
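Besides the strided-batched GEMM entry points, stream.h gains ThenDoHostCallbackWithStatus, which behaves like ThenDoHostCallback but takes a std::function<port::Status()> so the callback can report failure. A minimal usage sketch (the lambda body and the DrainHostBuffer helper are hypothetical, not part of this patch):

    // Hypothetical: verify a host-side condition after prior enqueued work.
    stream->ThenDoHostCallbackWithStatus([]() -> port::Status {
      if (!DrainHostBuffer()) {  // DrainHostBuffer is a hypothetical helper
        return port::Status(port::error::INTERNAL, "failed to drain buffer");
      }
      return port::Status::OK();
    });

As the following hunks show, a platform that only implements the void HostCallback overload gets a default adapter that invokes the status-returning callback and logs a warning on a non-OK status.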
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index 8297228e6f..7df6a361c6 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { StreamExecutorFactory MakeHostExecutorImplementation; +// TODO(b/112125301): Consolidate this down to one implementation of +// HostCallback, taking a callback that returns a Status. +bool StreamExecutorInterface::HostCallback( + Stream* stream, std::function<port::Status()> callback) { + return HostCallback(stream, [callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "HostCallback failed: " << s; + } + }); +} + } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 52b3dc04c4..59a477b5c9 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -239,6 +239,8 @@ class StreamExecutorInterface { const DeviceMemoryBase &gpu_src, uint64 size) = 0; virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0; + virtual bool HostCallback(Stream *stream, + std::function<port::Status()> callback); virtual port::Status AllocateEvent(Event *event) = 0; virtual port::Status DeallocateEvent(Event *event) = 0; virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 2e0137a485..9515d8e62a 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -699,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream, return implementation_->HostCallback(stream, std::move(callback)); } +bool StreamExecutor::HostCallback(Stream *stream, + std::function<port::Status()> callback) { + return implementation_->HostCallback(stream, std::move(callback)); +} + port::Status StreamExecutor::AllocateEvent(Event *event) { return implementation_->AllocateEvent(event); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 47b3a2b030..437f298616 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -549,6 +549,11 @@ class StreamExecutor { // See Stream::ThenDoHostCallback for full details. bool HostCallback(Stream *stream, std::function<void()> callback); + // Entrains on a stream a user-specified function to be run on the host. + // See Stream::ThenDoHostCallback for full details. + // This is the preferred form for a callback that may return an error. + bool HostCallback(Stream *stream, std::function<port::Status()> callback); + // Performs platform-specific allocation and initialization of an event.
port::Status AllocateEvent(Event *event); diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc index 47dd675834..cfc051fd09 100644 --- a/tensorflow/stream_executor/stream_test.cc +++ b/tensorflow/stream_executor/stream_test.cc @@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) { EXPECT_NE(sub_stream3, sub_stream4); } -TEST_F(StreamTest, FailedSubStreamNotReused) { +TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); Stream stream(executor.get()); stream.Init(); EXPECT_TRUE(stream.ok()); - // Get a sub-stream. + // Get sub_stream1. Stream* sub_stream1 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream1->ok()); - // Force an error on the stream; here we call a method that requires - // DNN support, which we know the Host platform doesn't support. + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. sub_stream1->ThenDepthConcatenate({}, {}, nullptr); EXPECT_FALSE(sub_stream1->ok()); @@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) { Stream* sub_stream2 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream2->ok()); - // The underlying streams should be different. They would have been - // the same, but since we forced an error on sub_stream1, it will - // not be re-used. Sadly we can't just check: + // The underlying sub_streams should be different. They would have been the + // same, but since we forced an error on sub_stream1, it will not be + // re-used. Sadly we can't just check: // EXPECT_NE(sub_stream1, sub_stream2); // - // The above should hold logically, but it may fail if the new - // stream instance allocated for sub_stream2 happens to reside in - // the same memory address as sub_stream1. + // The above should hold logically, but it may fail if the new Stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. // // The check that sub_stream2->ok() serves as a good-enough check. - // Return sub_stream2 and get sub_stream3. The previous error on - // sub_stream1 has no effect on these streams, and they are the - // same. + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. stream.ReturnSubStream(sub_stream2); Stream* sub_stream3 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream3->ok()); EXPECT_EQ(sub_stream2, sub_stream3); } +TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + // + // It is a bit weird to use sub_stream1 after it has already been returned. By + // doing this, we're simulating an asynchronous error that occurs during + // execution of the sub_stream, after the sub_stream has been returned. + // + // E.g. the following is a common pattern of usage, where the execution of the + // operations enqueued onto the sub streams may occur after the streams have + // already been returned.
+ // + // void EnqueueOnSubStreams(Stream* stream) { + // Stream* sub_stream1 = stream.GetOrCreateSubStream(); + // Stream* sub_stream2 = stream.GetOrCreateSubStream(); + // // ... enqueue some operations on the sub streams ... + // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); + // stream.ReturnSubStream(sub_stream1); + // stream.ReturnSubStream(sub_stream2); + // } + // + // Stream* main_stream = ...; + // EnqueueOnSubStreams(main_stream); + // main_stream.BlockHostUntilDone(); + // + // TODO(b/112196569): The semantics of failed sub-streams are error-prone; + // GetOrCreateSubStream can still return a sub-stream that has not encountered + // an error yet, but will encounter one in the future, based on previously + // enqueued operations. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Get sub_stream2. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying streams should be different. They would have been the same, + // but since we forced an error on sub_stream1, it will not be re-used. Sadly + // we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. stream.ReturnSubStream(sub_stream2); Stream* sub_stream3 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream3->ok()); EXPECT_EQ(sub_stream2, sub_stream3); }
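A side note on the sub-stream pool mechanics these tests exercise: both GetOrCreateSubStream and ReturnSubStream drop a failed sub-stream by swapping it with the last vector element and calling pop_back, a constant-time erase that sacrifices element order (harmless for a free pool). A standalone sketch of the idiom, illustrative rather than code from the patch:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Removes v[index] in O(1) by moving the last element into its slot.
    // Element order is not preserved.
    template <typename T>
    void SwapRemove(std::vector<T>* v, int64_t index) {
      const int64_t last = static_cast<int64_t>(v->size()) - 1;
      if (index != last) {
        std::swap((*v)[index], (*v)[last]);
      }
      v->pop_back();
    }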