Diffstat (limited to 'tensorflow/stream_executor')
22 files changed, 1344 insertions, 384 deletions
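Orientation note (not part of the diff): the headline change below adds `DoBlasGemmStridedBatched` entry points to `BlasSupport` and wires them to `cublas{S,D,C,Z}gemmStridedBatched` / `cublasGemmStridedBatchedEx`. The sketch that follows is a minimal CPU reference, under assumed conventions (column-major storage, no transposes, hypothetical function name), illustrating the addressing these methods rely on: batch `i`'s matrices live at fixed element offsets `i * stride_a` / `i * stride_b` / `i * stride_c` from single base pointers, so no per-batch pointer array or ScratchAllocator is needed.

```cpp
#include <cstdint>

// CPU reference for strided-batched GEMM (illustrative only):
//   C_i = alpha * A_i * B_i + beta * C_i,  i = 0 .. batch_count-1
// where A_i = a + i*stride_a, B_i = b + i*stride_b, C_i = c + i*stride_c.
// Matrices are column-major: A is m x k (leading dim lda), B is k x n (ldb),
// C is m x n (ldc).
void ReferenceGemmStridedBatched(int m, int n, int k, float alpha,
                                 const float* a, int lda, int64_t stride_a,
                                 const float* b, int ldb, int64_t stride_b,
                                 float beta, float* c, int ldc,
                                 int64_t stride_c, int batch_count) {
  for (int batch = 0; batch < batch_count; ++batch) {
    const float* a_i = a + batch * stride_a;  // batch i's A matrix
    const float* b_i = b + batch * stride_b;  // batch i's B matrix
    float* c_i = c + batch * stride_c;        // batch i's C matrix
    for (int col = 0; col < n; ++col) {
      for (int row = 0; row < m; ++row) {
        float acc = 0.0f;
        for (int x = 0; x < k; ++x) {
          acc += a_i[row + x * lda] * b_i[x + col * ldb];
        }
        c_i[row + col * ldc] = alpha * acc + beta * c_i[row + col * ldc];
      }
    }
  }
}
```

The CUDA implementations in the diff pass these strides straight through to cuBLAS; the fp16 overload additionally falls back to a per-batch `cublasSgemmEx` loop when strided-batched support is unavailable (CUDA < 9.1 or SM < 5.0).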
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ea87744b22..7f851e3646 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1121,6 +1121,40 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Batched gemm with strides instead of pointer arrays. + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + // Computes a matrix-matrix product where one input matrix is Hermitian: // // c <- alpha * a * b + beta * c, @@ -1990,6 +2024,38 @@ class BlasSupport { int ldb, std::complex<double> beta, \ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ + const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ + int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, 
double alpha, \ + const DeviceMemory<double> &a, int lda, int64 stride_a, \ + const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ + DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex<float> alpha, \ const DeviceMemory<std::complex<float>> &a, int lda, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 874bf0e8cb..ab7091b3f5 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx) #if CUDA_VERSION >= 8000 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched) #endif #if CUDA_VERSION >= 9000 @@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #if CUDA_VERSION >= 9010 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx) #endif } // namespace wrap @@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, } #endif cublasStatus_t ret = cublas_func(parent_, blas_, args...); - if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) { + if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": " << ToString(ret); } @@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm( stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); - // GPUs < sm_70 don't support Volta hardware. + // GPUs < sm_70 don't support tensor ops. if (cc_major >= 7 && TensorOpMathEnabled()) { use_tensor_ops = true; } @@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) { template <typename InType> static bool TensorOpsAvailable(int cc_major) { #if CUDA_VERSION >= 9000 + // cublas *does* allow tensor ops on inputs that are not fp16, so this is not + // strictly correct. We can't simply enable it, though, as that would change + // clients' behavior significantly: Using tensor ops on fp32 inputs cause them + // to be rounded to fp16. 
if (cc_major >= 7 && TensorOpMathEnabled() && std::is_same<InType, Eigen::half>::value) { return true; @@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (stream->parent()->GetDeviceDescription().cuda_compute_capability( &cc_major, &cc_minor) && cc_major < 5) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major + << cc_minor << " devices don't support explicit gemm algorithms."; return false; } if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) { + if (std::is_same<InT, Eigen::half>::value) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but tensor ops are not available in sm" + << cc_major << "X devices."; + } else { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but the input data type is not fp16."; + } return false; } // Either both 'alpha' and 'beta' need to be pointers to device memory, or // they need to be both host scalars. if (alpha.is_pointer() != beta.is_pointer()) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` " + "and `beta` is a pointer, but the other is not."; return false; } @@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (output_profile_result != nullptr) { timer.reset(new CUDATimer(parent_)); if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because " + "output_profile_result was given, but we were unable to " + "create a CUDATimer."; return false; } } @@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) && std::max({m, n, k}) >= 2097153 && cc_major < 7) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn " + "<9.2 bug with m, n, or k >= 2097153. See b/79126339."; return false; } #endif @@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error // state. if (!timer->Stop(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop " + "CUDATimer."; return false; } output_profile_result->set_is_valid(true); @@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( bool CUDABlas::GetBlasGemmAlgorithms( std::vector<blas::AlgorithmType> *out_algorithms) { -// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) -// were first introduced in CUDA 8. -// Note that when CUDA version and compute capability is not sufficient, we -// still return the out_algorithms. Caller needs to make sure that in this case, -// the returned vector is empty. - for (cublasGemmAlgo_t algo : { - CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, - CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, - CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, + // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) + // were first introduced in CUDA 8. + // + // Note that when CUDA version and compute capability is not sufficient, we + // still return the out_algorithms. Caller needs to make sure that in this + // case, the returned vector is empty. 
+ *out_algorithms = { + CUBLAS_GEMM_DFALT, + CUBLAS_GEMM_ALGO0, + CUBLAS_GEMM_ALGO1, + CUBLAS_GEMM_ALGO2, + CUBLAS_GEMM_ALGO3, + CUBLAS_GEMM_ALGO4, + CUBLAS_GEMM_ALGO5, + CUBLAS_GEMM_ALGO6, + CUBLAS_GEMM_ALGO7, #if CUDA_VERSION >= 9000 - CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10, - CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13, - CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16, - CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP, - CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP, - CUBLAS_GEMM_ALGO2_TENSOR_OP + CUBLAS_GEMM_ALGO8, + CUBLAS_GEMM_ALGO9, + CUBLAS_GEMM_ALGO10, + CUBLAS_GEMM_ALGO11, + CUBLAS_GEMM_ALGO12, + CUBLAS_GEMM_ALGO13, + CUBLAS_GEMM_ALGO14, + CUBLAS_GEMM_ALGO15, + CUBLAS_GEMM_ALGO16, + CUBLAS_GEMM_ALGO17, + CUBLAS_GEMM_DFALT_TENSOR_OP, + CUBLAS_GEMM_ALGO0_TENSOR_OP, + CUBLAS_GEMM_ALGO1_TENSOR_OP, + CUBLAS_GEMM_ALGO2_TENSOR_OP, + CUBLAS_GEMM_ALGO3_TENSOR_OP, + CUBLAS_GEMM_ALGO4_TENSOR_OP, #endif - }) { - out_algorithms->push_back(algo); - } +#if CUDA_VERSION >= 9200 + CUBLAS_GEMM_ALGO18, + CUBLAS_GEMM_ALGO19, + CUBLAS_GEMM_ALGO20, + CUBLAS_GEMM_ALGO21, + CUBLAS_GEMM_ALGO22, + CUBLAS_GEMM_ALGO23, + CUBLAS_GEMM_ALGO5_TENSOR_OP, + CUBLAS_GEMM_ALGO6_TENSOR_OP, + CUBLAS_GEMM_ALGO7_TENSOR_OP, + CUBLAS_GEMM_ALGO8_TENSOR_OP, + CUBLAS_GEMM_ALGO9_TENSOR_OP, + CUBLAS_GEMM_ALGO10_TENSOR_OP, + CUBLAS_GEMM_ALGO11_TENSOR_OP, + CUBLAS_GEMM_ALGO12_TENSOR_OP, + CUBLAS_GEMM_ALGO13_TENSOR_OP, + CUBLAS_GEMM_ALGO14_TENSOR_OP, + CUBLAS_GEMM_ALGO15_TENSOR_OP, +#endif + }; return true; } @@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched( return status.ok(); } +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) { + bool use_tensor_ops = false; +#if CUDA_VERSION >= 9000 + int cc_major, cc_minor; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor)) { + // GPUs < sm_70 don't support tensor ops. + if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#if CUDA_VERSION >= 9010 + if (cc_major >= 5) { + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + bool ok = DoBlasInternalImpl( + wrap::cublasGemmStridedBatchedEx, stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a, + CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); + if (ok) { + return true; + } + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } +#endif + } +#endif + // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. 
+ for (int batch = 0; batch < batch_count; ++batch) { + const auto *a_matrix = + reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a); + const auto *b_matrix = + reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b); + auto *c_matrix = + reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c); + bool ok = DoBlasInternalImpl( + wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */, + true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, + lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix, + SE_CUDA_DATA_HALF, ldc); + if (!ok) { + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } + } + return true; +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + 
CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex<float> alpha, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 84916385a8..55408ab9ab 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -322,6 +322,7 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) { CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {} port::Status CudnnSupport::Init() { + ScopedActivateExecutorContext context(parent_); cudnnHandle_t cudnn_handle = nullptr; auto status = cudnnCreate(&cudnn_handle); if (status == CUDNN_STATUS_SUCCESS) { @@ -791,6 +792,11 @@ class CudnnActivationDescriptor { double relu_ceiling = 0.0; cudnnActivationMode_t mode; switch (activation_mode) { +#if CUDNN_VERSION >= 7100 + case dnn::ActivationMode::kNone: + mode = CUDNN_ACTIVATION_IDENTITY; + break; +#endif case dnn::ActivationMode::kRelu6: relu_ceiling = 6.0; mode = CUDNN_ACTIVATION_CLIPPED_RELU; @@ -1980,15 +1986,14 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -1996,8 +2001,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( cudnn.handle(), /*xDesc=*/input_nd.handle(), /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2022,15 +2033,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace( port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionBackwardDataWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. 
It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2040,8 +2050,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*dxDesc=*/input_nd.handle(), - /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), + /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2066,15 +2082,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. 
size_t size_in_bytes; @@ -2084,8 +2099,14 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), - /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), + /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2132,7 +2153,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm( } auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2149,11 +2170,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( @@ -2181,7 +2202,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2198,11 +2219,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( @@ -2230,7 +2251,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2247,11 +2268,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN(*scratch, - AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algorithm_config.algorithm(), input_nd, - filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, 
AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } // A helper class to set env-vars and choose options for cudnn-related @@ -2480,10 +2501,11 @@ port::Status CudnnSupport::DoFusedConvolveImpl( DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - if (activation_mode != dnn::ActivationMode::kRelu) { + if (activation_mode != dnn::ActivationMode::kRelu && + activation_mode != dnn::ActivationMode::kNone) { return port::Status(port::error::INVALID_ARGUMENT, "cudnnConvolutionBiasActivationForward() only supports " - "Relu activation."); + "Relu or None activation."); } CudnnTensorDescriptor conv_input_nd( @@ -3075,8 +3097,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl( } // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not - // zero-initialized. - // TODO(timshen): Add an nvbugs/ link. + // zero-initialized, nvbugs/2254619. if (CUDNN_VERSION >= 7000 && algorithm_config.algorithm().algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && @@ -3603,7 +3624,7 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<double>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<double>* output_data) { + DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. @@ -3628,7 +3649,7 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<float>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<float>* output_data) { + DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3653,7 +3674,8 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<Eigen::half>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<Eigen::half>* output_data) { + DeviceMemory<Eigen::half>* output_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3679,7 +3701,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<double>& output_data, const DeviceMemory<double>& input_diff_data, - DeviceMemory<double>* output_diff_data) { + DeviceMemory<double>* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. @@ -3708,7 +3731,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<float>& output_data, const DeviceMemory<float>& input_diff_data, - DeviceMemory<float>* output_diff_data) { + DeviceMemory<float>* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. 
@@ -3737,7 +3761,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<Eigen::half>& output_data, const DeviceMemory<Eigen::half>& input_diff_data, - DeviceMemory<Eigen::half>* output_diff_data) { + DeviceMemory<Eigen::half>* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3806,7 +3831,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data, const DeviceMemory<float>& normalized_data, const DeviceMemory<float>& normalized_variable_gradient, - DeviceMemory<float>* raw_variable_gradient) { + DeviceMemory<float>* raw_variable_gradient, + ScratchAllocator* workspace_allocator) { // Check for unsupported modes. if (normalize_descriptor.wrap_around()) { LOG(ERROR) << "CUDA LRN does not support cudnn-around mode"; diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index c924d41cb5..9d88f971bb 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -515,21 +515,24 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<double>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<double>* output_data) override; + DeviceMemory<double>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<float>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<float>* output_data) override; + DeviceMemory<float>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<Eigen::half>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<Eigen::half>* output_data) override; + DeviceMemory<Eigen::half>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -538,7 +541,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<double>& output_data, const DeviceMemory<double>& input_diff_data, - DeviceMemory<double>* output_diff_data) override; + DeviceMemory<double>* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -547,7 +551,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<float>& output_data, const DeviceMemory<float>& input_diff_data, - DeviceMemory<float>* output_diff_data) override; + DeviceMemory<float>* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -556,7 +561,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<Eigen::half>& output_data, const DeviceMemory<Eigen::half>& input_diff_data, - DeviceMemory<Eigen::half>* output_diff_data) override; + DeviceMemory<Eigen::half>* output_diff_data, + ScratchAllocator* workspace_allocator) 
override; bool DoNormalize(Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, @@ -575,7 +581,8 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory<float>& raw_data, const DeviceMemory<float>& normalized_data, const DeviceMemory<float>& normalized_variable_gradient, - DeviceMemory<float>* raw_variable_gradient) override; + DeviceMemory<float>* raw_variable_gradient, + ScratchAllocator* workspace_allocator) override; bool DoDepthConcatenate( Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions, diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index d508f6594a..f982f34b98 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/human_readable.h" #include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/notification.h" +#include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/strcat.h" @@ -66,14 +67,17 @@ class CreatedContexts { return Live()->find(context) != Live()->end(); } - // Adds context to the live set. + // Adds context to the live set, or returns it if it's already present. static CudaContext* Add(CUcontext context) { CHECK(context != nullptr); mutex_lock lock(mu_); - auto cuda_context = new CudaContext(context, next_id_++); - Live()->insert( - std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context))); - return cuda_context; + auto insert_result = Live()->insert(std::make_pair(context, nullptr)); + auto it = insert_result.first; + if (insert_result.second) { + // context was not present in the map. Add it. + it->second = MakeUnique<CudaContext>(context, next_id_++); + } + return it->second.get(); } // Removes context from the live set. @@ -102,117 +106,16 @@ class CreatedContexts { /* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context" // Formats CUresult to output prettified values into a log stream. -// Error summaries taken from: -// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc6c391505e117393cc2558fff6bfc2e9 -// -// TODO(leary) switch to cuGetErrorName when updated cuda.h is available. string ToString(CUresult result) { -#define OSTREAM_CUDA_ERROR(__name) \ - case CUDA_ERROR_##__name: \ - return "CUDA_ERROR_" #__name; - -/////////////// -// NOTE: here we specify return code values outside of the enum explicitly -// because our in-tree cuda.h is from the CUDA 5.5 SDK, but CUDA 6.0+ driver -// libraries are deployed in the fleet these error codes are backwards -// compatible, but if we see a "new" one, we want to be able to identify it in -// the logs. -// -// Once we get a cuda.h that has cuGetErrorName (TODO is above) we can -// eliminate this function and just rely on the driver to provide us these -// strings. -// -// NOTE: "Must reboot all context" below is shorthand for, "must -// destroy/recreate the offending context and any allocation which come from -// it if you are to continue using CUDA." 
-#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wswitch" - switch (result) { - OSTREAM_CUDA_ERROR(INVALID_VALUE) - OSTREAM_CUDA_ERROR(OUT_OF_MEMORY) - OSTREAM_CUDA_ERROR(NOT_INITIALIZED) - OSTREAM_CUDA_ERROR(DEINITIALIZED) - OSTREAM_CUDA_ERROR(NO_DEVICE) - OSTREAM_CUDA_ERROR(INVALID_DEVICE) - OSTREAM_CUDA_ERROR(INVALID_IMAGE) - OSTREAM_CUDA_ERROR(INVALID_CONTEXT) - OSTREAM_CUDA_ERROR(INVALID_HANDLE) - OSTREAM_CUDA_ERROR(NOT_FOUND) - OSTREAM_CUDA_ERROR(NOT_READY) - OSTREAM_CUDA_ERROR(NO_BINARY_FOR_GPU) - - // Encountered an uncorrectable ECC error during execution. - OSTREAM_CUDA_ERROR(ECC_UNCORRECTABLE) - - // Load/store on an invalid address. Must reboot all context. - case 700: - return "CUDA_ERROR_ILLEGAL_ADDRESS"; - // Passed too many / wrong arguments, too many threads for register count. - case 701: - return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES"; - // Kernel took too long to execute. - case 702: - return "CUDA_ERROR_LAUNCH_TIMEOUT"; - // Kernel launch uses an incompatible texturing mode. - case 703: - return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING"; - // Trying to re-enable peer access that already has it enabled. - case 704: - return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED"; - // Trying to disable peer access that has not yet been enabled. - case 705: - return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED"; - // Primary context for the specified device has already been initialized. - case 708: - return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE"; - // Context current to calling thread has been destroyed or is a primary - // context that has not yet been initialized. - case 709: - return "CUDA_ERROR_CONTEXT_IS_DESTROYED"; - // Device-side assert triggered during kernel execution. Must reboot all - // context. - case 710: - return "CUDA_ERROR_ASSERT"; - // Hardware resources to enable peer access have been exhausted. - case 711: - return "CUDA_ERROR_TOO_MANY_PEERS"; - // Memory range has already been registered. - case 712: - return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED"; - // Pointer does not correspond to any currently registered memory region. - case 713: - return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED"; - // Due to stack corruption or exceeding stack size limit. Must reboot all - // context. - case 714: - return "CUDA_ERROR_HARDWARE_STACK_ERROR"; - case 715: - return "CUDA_ERROR_ILLEGAL_INSTRUCTION"; - // Load/store on an unaligned memory address. Must reboot all context. - case 716: - return "CUDA_ERROR_MISALIGNED_ADDRESS"; - // Device instruction with specific address space given address not - // belonging to allowed address space. Must reboot all context. - case 717: - return "CUDA_ERROR_INVALID_ADDRESS_SPACE"; - // Device program counter wrapped its address space. Must reboot all - // context. - case 718: - return "CUDA_ERROR_INVALID_PC"; - // Exception on device while executing a kernel; e.g. deref invalid device - // pointer, accessing OOB shared memory. Must reboot all context. - case 719: - return "CUDA_ERROR_LAUNCH_FAILED"; - - OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE) - OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED) - OSTREAM_CUDA_ERROR(NOT_PERMITTED) - OSTREAM_CUDA_ERROR(NOT_SUPPORTED) - OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA. 
- default: - return port::StrCat("CUresult(", static_cast<int>(result), ")"); + const char *error_name; + if (cuGetErrorName(result, &error_name)) { + return port::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")"); + } + const char *error_string; + if (cuGetErrorString(result, &error_string)) { + return error_name; } -#pragma GCC diagnostic pop + return port::StrCat(error_name, ": ", error_string); } // Returns the current context and checks that it is in the set of CUDA contexts @@ -528,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, *context = CreatedContexts::Add(new_context); CHECK(*context != nullptr) << "success in this call must entail non-null result"; - VLOG(2) << "created context " << context << " for this thread"; + VLOG(2) << "created or reused context " << context << " for this thread"; return port::Status::OK(); } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index f11022ef1d..73f05b94db 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -206,6 +206,48 @@ static string GetBinaryDir(bool strip_exe) { return exe_path; } +bool CUDAExecutor::LoadModuleFromCuBin(const char *cubin, CUmodule *module) { + uint64_t module_refcount; + std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin]; + + if (*module == nullptr) { + auto load_status = CUDADriver::LoadCubin(context_, cubin, module); + if (!load_status.ok()) { + LOG(ERROR) << "failed to load CUBIN: " << load_status; + return false; + } + module_refcount = 1; + VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin) + << " as module " << *module; + } else { + ++module_refcount; + VLOG(3) << "CUBIN " << static_cast<const void *>(cubin) + << " is already loaded as module " << *module; + } + gpu_binary_to_module_[cubin] = {*module, module_refcount}; + return true; +} + +bool CUDAExecutor::LoadModuleFromPtx(const char *ptx, CUmodule *module) { + uint64_t module_refcount; + std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx]; + + if (*module == nullptr) { + if (!CUDADriver::LoadPtx(context_, ptx, module)) { + return false; + } + VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx) << " as module " + << *module; + module_refcount = 1; + } else { + ++module_refcount; + VLOG(3) << "PTX " << static_cast<const void *>(ptx) + << " is already loaded as module " << module; + } + gpu_binary_to_module_[ptx] = {*module, module_refcount}; + return true; +} + bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) { CUDAKernel *cuda_kernel = AsCUDAKernel(kernel); @@ -215,28 +257,13 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); if (spec.has_cuda_cubin_in_memory()) { + mutex_lock lock{in_memory_modules_mu_}; kernelname = &spec.cuda_cubin_in_memory().kernelname(); const char *cubin = spec.cuda_cubin_in_memory().bytes(); - mutex_lock lock{in_memory_modules_mu_}; - uint64_t module_refcount; - std::tie(module, module_refcount) = gpu_binary_to_module_[cubin]; - - if (module == nullptr) { - auto load_status = CUDADriver::LoadCubin(context_, cubin, &module); - if (!load_status.ok()) { - LOG(ERROR) << "failed to load CUBIN: " << load_status; - return false; - } - module_refcount = 1; - VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin) - << " as module " << module; - } else { - ++module_refcount; - VLOG(3) 
<< "CUBIN " << static_cast<const void *>(cubin) - << " is already loaded as module " << module; + if (!LoadModuleFromCuBin(cubin, &module)) { + return false; } kernel_to_gpu_binary_[kernel] = cubin; - gpu_binary_to_module_[cubin] = {module, module_refcount}; } else if (spec.has_cuda_ptx_in_memory()) { kernelname = &spec.cuda_ptx_in_memory().kernelname(); @@ -254,24 +281,10 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, } mutex_lock lock{in_memory_modules_mu_}; - uint64_t module_refcount; - std::tie(module, module_refcount) = gpu_binary_to_module_[ptx]; - - if (module == nullptr) { - if (!CUDADriver::LoadPtx(context_, ptx, &module)) { - LOG(ERROR) << "failed to load PTX for kernel " << *kernelname; - return false; - } - VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx) - << " as module " << module; - module_refcount = 1; - } else { - ++module_refcount; - VLOG(3) << "PTX " << static_cast<const void *>(ptx) - << " is already loaded as module " << module; + if (!LoadModuleFromPtx(ptx, &module)) { + return false; } kernel_to_gpu_binary_[kernel] = ptx; - gpu_binary_to_module_[ptx] = {module, module_refcount}; } else { LOG(WARNING) << "no method of loading CUDA kernel provided"; return false; @@ -295,6 +308,23 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, return true; } +bool CUDAExecutor::UnloadGpuBinary(const void *gpu_binary) { + auto module_it = gpu_binary_to_module_.find(gpu_binary); + if (gpu_binary_to_module_.end() == module_it) { + VLOG(3) << "No loaded CUDA module for " << gpu_binary; + return false; + } + auto &module = module_it->second.first; + auto &refcount = module_it->second.second; + VLOG(3) << "Found CUDA module " << module << " with refcount " << refcount; + if (--refcount == 0) { + VLOG(3) << "Unloading CUDA module " << module; + CUDADriver::UnloadModule(context_, module); + gpu_binary_to_module_.erase(module_it); + } + return true; +} + void CUDAExecutor::UnloadKernel(const KernelBase *kernel) { VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name(); @@ -307,25 +337,52 @@ void CUDAExecutor::UnloadKernel(const KernelBase *kernel) { } VLOG(3) << "Kernel " << kernel << " : " << kernel->name() << " has loaded GPU code " << gpu_binary_it->second; - auto module_it = gpu_binary_to_module_.find(gpu_binary_it->second); - if (gpu_binary_to_module_.end() == module_it) { - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has no loaded CUDA module."; - return; // This kernel never loaded any modules - } - auto &module = module_it->second.first; - auto &refcount = module_it->second.second; - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has loaded GPU code " << gpu_binary_it->second - << " into CUDA module " << module << " with refcount " << refcount; - if (--refcount == 0) { - VLOG(3) << "Unloading CUDA module " << module; - CUDADriver::UnloadModule(context_, module); - gpu_binary_to_module_.erase(module_it); - } + UnloadGpuBinary(gpu_binary_it->second); kernel_to_gpu_binary_.erase(gpu_binary_it); } +bool CUDAExecutor::LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + // In CUDAExecutor we store the pointer to the GPU binary (PTX or CUBIN) as + // ModuleHandle::id(). 
+ CUmodule cu_module; + if (spec.has_cuda_cubin_in_memory()) { + mutex_lock lock{in_memory_modules_mu_}; + if (!LoadModuleFromCuBin( + reinterpret_cast<const char *>(spec.cuda_cubin_in_memory().data()), + &cu_module)) { + return false; + } + *module_handle = ModuleHandle(const_cast<void *>( + static_cast<const void *>(spec.cuda_cubin_in_memory().data()))); + return true; + } else if (spec.has_cuda_ptx_in_memory()) { + if (cc_major_ == 0 && cc_minor_ == 0) { + return false; + } + + if (!spec.cuda_ptx_in_memory()) { + return false; + } + + mutex_lock lock{in_memory_modules_mu_}; + if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) { + return false; + } + *module_handle = ModuleHandle(const_cast<void *>( + static_cast<const void *>(spec.cuda_ptx_in_memory()))); + return true; + } + LOG(WARNING) << "no method of loading CUDA module provided"; + return false; +} + +bool CUDAExecutor::UnloadModule(ModuleHandle module_handle) { + const char *gpu_binary = reinterpret_cast<const char *>(module_handle.id()); + mutex_lock lock{in_memory_modules_mu_}; + return UnloadGpuBinary(gpu_binary); +} + bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel, KernelMetadata *kernel_metadata) { int value; @@ -783,16 +840,26 @@ bool CUDAExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const { return CUDADriver::GetDeviceMemoryInfo(context_, free, total); } -bool CUDAExecutor::GetSymbol(const string& symbol_name, void **mem, +bool CUDAExecutor::GetSymbol(const string &symbol_name, + ModuleHandle module_handle, void **mem, size_t *bytes) { + auto lookup_in_module = [&](CUmodule module) { + CHECK(module != nullptr); + return CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(), + reinterpret_cast<CUdeviceptr *>(mem), + bytes); + }; + { // give limited scope to mutex_lock mutex_lock lock{in_memory_modules_mu_}; + if (static_cast<bool>(module_handle)) { + auto it = gpu_binary_to_module_.find(module_handle.id()); + CHECK(it != gpu_binary_to_module_.end()); + return lookup_in_module(it->second.first); + } + for (auto &it : gpu_binary_to_module_) { - CUmodule module = it.second.first; - CHECK(module != nullptr); - if (CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(), - reinterpret_cast<CUdeviceptr *>(mem), - bytes)) { + if (lookup_in_module(it.second.first)) { return true; } } @@ -844,7 +911,7 @@ CUDAExecutor::GetTimerImplementation() { return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this)); } -void *CUDAExecutor::CudaContextHack() { return context_; } +void *CUDAExecutor::GpuContextHack() { return context_; } CudaContext* CUDAExecutor::cuda_context() { return context_; } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 773cbfb8a1..8a954d5461 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -62,6 +62,9 @@ class CUDAExecutor : public internal::StreamExecutorInterface { bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) override; void UnloadKernel(const KernelBase *kernel) override; + bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) override; + bool UnloadModule(ModuleHandle module_handle) override; bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, @@ -175,7 +178,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface { // Search for the symbol and returns a device pointer and size. 
// Returns false if symbol does not exist. - bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) override; + bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes) override; DeviceDescription *PopulateDeviceDescription() const override; @@ -210,7 +214,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface { std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override; - void *CudaContextHack() override; + void *GpuContextHack() override; CudaContext* cuda_context(); @@ -239,6 +243,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface { void VlogOccupancyInfo(const KernelBase &kernel, const ThreadDim &thread_dims, const BlockDim &block_dims); + bool LoadModuleFromCuBin(const char *cubin, CUmodule *module) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + + // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated. + bool LoadModuleFromPtx(const char *ptx, CUmodule *module) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + + bool UnloadGpuBinary(const void *gpu_binary) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + // Guards the in-memory-module mapping. mutex in_memory_modules_mu_; diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index 02edff6431..bb8bda4755 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface { // Note: teardown is handled by a parent's call to DeallocateStream. ~CUDAStream() override {} - void *CudaStreamHack() override { return cuda_stream_; } - void **CudaStreamMemberHack() override { + void *GpuStreamHack() override { return cuda_stream_; } + void **GpuStreamMemberHack() override { return reinterpret_cast<void **>(&cuda_stream_); } diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 82aa8ceb32..2a30f922bc 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -117,6 +117,8 @@ string FilterLayoutString(FilterLayout layout) { switch (layout) { case FilterLayout::kOutputInputYX: return "OutputInputYX"; + case FilterLayout::kOutputYXInput: + return "OutputYXInput"; case FilterLayout::kOutputInputYX4: return "OutputInputYX4"; case FilterLayout::kInputYXOutput: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 9eca5abe1a..9abfa1db6a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -713,15 +713,23 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} + AlgorithmDesc() + : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {} AlgorithmDesc(Index a, bool use_tensor_ops) - : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} + : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {} + AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size) + : algo_(a), + tensor_ops_enabled_(use_tensor_ops), + scratch_size_(scratch_size) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } bool tensor_ops_enabled() const { return tensor_ops_enabled_; } Index algo_id() const { return algo_; } + size_t scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmDesc& other) const { return this->algo_ == 
other.algo_ && - this->tensor_ops_enabled_ == other.tensor_ops_enabled_; + this->tensor_ops_enabled_ == other.tensor_ops_enabled_ && + this->scratch_size_ == other.scratch_size_; } uint64 hash() const; @@ -729,6 +737,7 @@ class AlgorithmDesc { enum { kDefaultAlgorithm = -1 }; Index algo_; bool tensor_ops_enabled_; + size_t scratch_size_; }; // Describes the result from a perf experiment. @@ -1552,14 +1561,16 @@ class DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<float>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<float>* output_data) = 0; + DeviceMemory<float>* output_data, + ScratchAllocator* workspace_allocator) = 0; virtual bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<double>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<double>* output_data) { + DeviceMemory<double>* output_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolForward not implemented for double."; return false; } @@ -1569,7 +1580,8 @@ class DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<Eigen::half>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<Eigen::half>* output_data) { + DeviceMemory<Eigen::half>* output_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolForward not implemented for float16."; return false; } @@ -1582,7 +1594,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<double>& output_data, const DeviceMemory<double>& input_diff_data, - DeviceMemory<double>* output_diff_data) { + DeviceMemory<double>* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1594,7 +1607,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<float>& output_data, const DeviceMemory<float>& input_diff_data, - DeviceMemory<float>* output_diff_data) { + DeviceMemory<float>* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1606,7 +1620,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<Eigen::half>& output_data, const DeviceMemory<Eigen::half>& input_diff_data, - DeviceMemory<Eigen::half>* output_diff_data) { + DeviceMemory<Eigen::half>* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1653,7 +1668,8 @@ class DnnSupport { const DeviceMemory<float>& raw_data, const DeviceMemory<float>& normalized_data, const DeviceMemory<float>& normalized_variable_gradient, - DeviceMemory<float>* raw_variable_gradient) { + DeviceMemory<float>* raw_variable_gradient, + ScratchAllocator* workspace_allocator) { return false; } diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 3cd97b3cf1..8adf739b17 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -93,7 +93,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, // the nature of the HostExecutor) memcpy on the stream (HostStream) // associated with the HostExecutor. 
AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); return true; } diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index e82f57569f..7ba1f18101 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface { uint64 size) override; // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override; bool SynchronousMemSet(DeviceMemoryBase *location, int value, @@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface { return std::unique_ptr<internal::TimerInterface>(new HostTimer()); } - void *CudaContextHack() override { return nullptr; } + void *GpuContextHack() override { return nullptr; } private: const PluginConfig plugin_config_; diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc index 5a7d3b3dd4..bfbfb56cd7 100644 --- a/tensorflow/stream_executor/host/host_stream.cc +++ b/tensorflow/stream_executor/host/host_stream.cc @@ -28,18 +28,28 @@ HostStream::HostStream() HostStream::~HostStream() {} bool HostStream::EnqueueTask(std::function<void()> task) { + struct NotifiedTask { + HostStream* stream; + std::function<void()> task; + + void operator()() { + task(); + // Destroy the task before unblocking its waiters, as BlockHostUntilDone() + // should guarantee that all tasks are destroyed. + task = std::function<void()>(); + { + mutex_lock lock(stream->mu_); + --stream->pending_tasks_; + } + stream->completion_condition_.notify_all(); + } + }; + { mutex_lock lock(mu_); ++pending_tasks_; } - host_executor_->Schedule([this, task]() { - task(); - { - mutex_lock lock(mu_); - --pending_tasks_; - } - completion_condition_.notify_all(); - }); + host_executor_->Schedule(NotifiedTask{this, std::move(task)}); return true; } diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 5d7b8a3782..be88f074cf 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface { bool EnqueueTask(std::function<void()> task); - void *CudaStreamHack() override { return nullptr; } - void **CudaStreamMemberHack() override { return nullptr; } + void *GpuStreamHack() override { return nullptr; } + void **GpuStreamMemberHack() override { return nullptr; } void BlockUntilDone(); diff --git a/tensorflow/stream_executor/module_spec.h b/tensorflow/stream_executor/module_spec.h new file mode 100644 index 0000000000..75bdfed2d7 --- /dev/null +++ b/tensorflow/stream_executor/module_spec.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ +#define TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ + +#include "tensorflow/stream_executor/lib/array_slice.h" +#include "tensorflow/stream_executor/lib/stringpiece.h" +#include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/platform/port.h" + +namespace stream_executor { + +// Describes how to load a module on a target platform. +// +// The exact meaning of a "module" may differ from platform to platform but +// loosely speaking a module is a collection of kernels and global variables. It +// corresponds to CUmodule when running on CUDA. +class MultiModuleLoaderSpec { + public: + bool has_cuda_cubin_in_memory() const { return has_cuda_cubin_in_memory_; } + port::ArraySlice<const uint8> cuda_cubin_in_memory() const { + CHECK(has_cuda_cubin_in_memory()); + return {cuda_cubin_in_memory_.data(), cuda_cubin_in_memory_.size()}; + } + + bool has_cuda_ptx_in_memory() const { return has_cuda_ptx_in_memory_; } + const char* cuda_ptx_in_memory() const { + CHECK(has_cuda_ptx_in_memory()); + return cuda_ptx_in_memory_; + } + + void AddCudaCubinInMemory(port::ArraySlice<const uint8> cubin_bytes) { + CHECK(!cubin_bytes.empty()); + has_cuda_cubin_in_memory_ = true; + cuda_cubin_in_memory_ = cubin_bytes; + } + + void AddCudaPtxInMemory(const char* ptx) { + has_cuda_ptx_in_memory_ = true; + // The CUDA driver does not like getting an empty string as PTX. + cuda_ptx_in_memory_ = *ptx ? ptx : nullptr; + } + + private: + port::ArraySlice<const uint8> cuda_cubin_in_memory_; + bool has_cuda_cubin_in_memory_ = false; + const char* cuda_ptx_in_memory_; + bool has_cuda_ptx_in_memory_ = false; +}; + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 9369183133..9efd34de24 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto &param : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -267,6 +268,12 @@ Stream::Stream(StreamExecutor *parent, Stream::~Stream() { VLOG_CALL(); + // Ensure the stream is completed.
+ auto status = BlockHostUntilDone(); + if (!status.ok()) { + LOG(WARNING) << "Error blocking host until done in stream destructor: " + << status; + } temporary_memory_manager_.ForceDeallocateAll(); if (allocated_) { @@ -1377,15 +1384,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<double> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<double> *output_data) { + DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1401,15 +1409,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data) { + DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1422,15 +1431,17 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<Eigen::half> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<Eigen::half> *output_data) { + DeviceMemory<Eigen::half> *output_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1445,16 +1456,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<double> &output_data, const DeviceMemory<double> &input_diff_data, - DeviceMemory<double> *output_diff_data) { + DeviceMemory<double> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + 
workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1472,16 +1486,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<float> &output_data, const DeviceMemory<float> &input_diff_data, - DeviceMemory<float> *output_diff_data) { + DeviceMemory<float> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1496,16 +1513,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<Eigen::half> &output_data, const DeviceMemory<Eigen::half> &input_diff_data, - DeviceMemory<Eigen::half> *output_diff_data) { + DeviceMemory<Eigen::half> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1552,16 +1572,18 @@ Stream &Stream::ThenNormalizeBackwardWithDimensions( const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data, const DeviceMemory<float> &normalized_data, const DeviceMemory<float> &normalized_variable_gradient, - DeviceMemory<float> *raw_variable_gradient) { + DeviceMemory<float> *raw_variable_gradient, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), PARAM(normalized_data), PARAM(normalized_variable_gradient), - PARAM(raw_variable_gradient)); + PARAM(raw_variable_gradient), PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoNormalizeBackwardWithDimensions( this, normalize_descriptor, dimensions, raw_data, normalized_data, - normalized_variable_gradient, raw_variable_gradient)); + normalized_variable_gradient, raw_variable_gradient, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1901,30 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. + for (int64 index = 0; index < sub_streams_.size();) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. 
+ Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - stream.second = true; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1933,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1944,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1957,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -1974,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. 
Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4657,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( scratch_allocator); } +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<Eigen::half> &, int, int64, + const DeviceMemory<Eigen::half> &, int, int64, float, + DeviceMemory<Eigen::half> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<float> &, int, int64, + const DeviceMemory<float> &, int, int64, float, + DeviceMemory<float> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, + const DeviceMemory<double> &, int, int64, + const DeviceMemory<double> &, int, int64, double, + DeviceMemory<double> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const 
DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, + int, int64, const DeviceMemory<std::complex<float>> &, int, + int64, std::complex<float>, DeviceMemory<std::complex<float>> *, + int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, + int, int64, const DeviceMemory<std::complex<double>> &, int, + int64, std::complex<double>, + DeviceMemory<std::complex<double>> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); @@ -4665,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast<const void *>(seed) << "; bytes: " << seed_bytes; } @@ -4683,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4699,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4715,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, 
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4730,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4746,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4762,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4777,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4790,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4803,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4815,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4828,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5097,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; } +Stream &Stream::ThenDoHostCallbackWithStatus( + 
std::function<port::Status()> callback) { + VLOG_CALL(PARAM(callback)); + + if (ok()) { + CheckError(parent_->HostCallback(this, std::move(callback))); + } else { + LOG(WARNING) << "stream " << DebugStreamPointers() + << " was in error state before adding host callback"; + } + return *this; +} + Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<std::complex<float>> &input, DeviceMemory<std::complex<float>> *output) { @@ -5113,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5130,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5146,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5162,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5179,8 +5387,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5196,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5224,7 +5434,7 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } @@ -5235,4 +5445,10 @@ port::Status Stream::BlockHostUntilDone() { return error; } +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. 
+ return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); +} + } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index e8885e1eb6..e1629b5b30 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -122,10 +122,14 @@ class Stream { // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused - // later. + // later. Sub-streams that are !ok() will not be reused. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_); // Allocate temporary memories. The stream will deallocate them when blocked @@ -629,19 +633,22 @@ class Stream { const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<double> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<double> *output_data); + DeviceMemory<double> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data); + DeviceMemory<float> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<Eigen::half> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<Eigen::half> *output_data); + DeviceMemory<Eigen::half> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -649,7 +656,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<double> &output_data, const DeviceMemory<double> &input_diff_data, - DeviceMemory<double> *output_diff_data); + DeviceMemory<double> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -657,7 +665,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<float> &output_data, const DeviceMemory<float> &input_diff_data, - DeviceMemory<float> *output_diff_data); + DeviceMemory<float> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -665,7 +674,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<Eigen::half> &output_data, const DeviceMemory<Eigen::half> &input_diff_data, - DeviceMemory<Eigen::half> *output_diff_data); + DeviceMemory<Eigen::half> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor, const DeviceMemory<float> &input_data, @@ -684,7 +694,8 @@ class Stream { const DeviceMemory<float> &raw_data, const DeviceMemory<float> 
&normalized_data, const DeviceMemory<float> &normalized_variable_gradient, - DeviceMemory<float> *raw_variable_gradient); + DeviceMemory<float> *raw_variable_gradient, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenActivate(dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &dimensions, @@ -1550,6 +1561,38 @@ class Stream { std::complex<double> beta, const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count); // See BlasSupport::DoBlasHemm. Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, @@ -2002,6 +2045,11 @@ class Stream { // negative effects on performance. Stream &ThenDoHostCallback(std::function<void()> callback); + // Entrains onto the stream a callback to the host (from the device). + // Behaves as ThenDoHostCallback above, but returns a Status instead of void. + // This overload should be preferred if the callback could fail. + Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback); + // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); @@ -2012,6 +2060,9 @@ class Stream { // with this stream. internal::TemporaryMemoryManager *temporary_memory_manager(); + // Returns a debugging string "[stream=0x...,impl=0x...]". + string DebugStreamPointers() const; + private: friend class host::HostBlas; // for parent_. friend class host::HostFft; // for parent_. 
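The declarations above add a strided-batched GEMM path to Stream. As a hedged usage sketch (the wrapper name, shapes, and scaling factors below are illustrative assumptions, not part of this change), the float overload can be driven like this for batch_count column-major m x k and k x n matrices packed back-to-back in device memory:

#include "tensorflow/stream_executor/stream.h"

namespace stream_executor {

// Illustrative sketch: multiplies batch_count pairs of column-major matrices
// stored contiguously in `a` and `b`, writing the m x n products contiguously
// into `c`. The stream and device buffers are assumed to be set up elsewhere.
void BatchedMatmulSketch(Stream* stream, const DeviceMemory<float>& a,
                         const DeviceMemory<float>& b, DeviceMemory<float>* c,
                         int m, int n, int k, int batch_count) {
  const int64 stride_a = static_cast<int64>(m) * k;  // elements between consecutive A matrices
  const int64 stride_b = static_cast<int64>(k) * n;
  const int64 stride_c = static_cast<int64>(m) * n;
  stream->ThenBlasGemmStridedBatched(
      blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k,
      /*alpha=*/1.0f, a, /*lda=*/m, stride_a, b, /*ldb=*/k, stride_b,
      /*beta=*/0.0f, c, /*ldc=*/m, stride_c, batch_count);
}

}  // namespace stream_executor

Unlike the existing ThenBlasGemmBatched overloads, which take port::ArraySlice arrays of per-matrix device pointers plus a ScratchAllocator, the strided form describes the whole batch with three strides and takes no scratch allocator.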
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index 8297228e6f..7df6a361c6 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { StreamExecutorFactory MakeHostExecutorImplementation; +// TODO(b/112125301): Consolidate this down to one implementation of +// HostCallback, taking a callback that returns a Status. +bool StreamExecutorInterface::HostCallback( + Stream* stream, std::function<port::Status()> callback) { + return HostCallback(stream, [callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "HostCallback failed: " << s; + } + }); +} + } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 9c989b971d..59a477b5c9 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -36,20 +36,38 @@ limitations under the License. #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/kernel_spec.h" #include "tensorflow/stream_executor/launch_dim.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/module_spec.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/shared_memory_config.h" #include "tensorflow/stream_executor/trace_listener.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" namespace stream_executor { class Stream; class Timer; +// An opaque handle to a loaded module. +// +// An instance of this is produced by StreamExecutor::LoadModule. +class ModuleHandle { + public: + /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {} + + // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a + // null pointer. + void *id() const { return id_; } + + explicit operator bool() const { return id() != nullptr; } + + private: + void *id_; +}; + namespace internal { // Platform-dependent interface class for the generic Events interface, in @@ -100,19 +118,20 @@ class StreamInterface { // Default destructor for the abstract interface. virtual ~StreamInterface() {} - // Returns the CUDA stream associated with this platform's stream + // Returns the GPU stream associated with this platform's stream // implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaStreamHack() { return nullptr; } - - // See the above comment on CudaStreamHack -- this further breaks abstraction - // for Eigen within distbelief, which has strong ties to CUDA as a platform, - // and a historical attachment to a programming model which takes a + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not.
This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or + // ROCm as a platform. + virtual void *GpuStreamHack() { return nullptr; } + + // See the above comment on GpuStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a + // platform, and a historical attachment to a programming model which takes a // stream-slot rather than a stream-value. - virtual void **CudaStreamMemberHack() { return nullptr; } + virtual void **GpuStreamMemberHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); @@ -163,6 +182,11 @@ class StreamExecutorInterface { KernelBase *kernel) { return false; } + virtual bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + return false; + } + virtual bool UnloadModule(ModuleHandle module_handle) { return false; } virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, const KernelArgsArrayBase &args) { @@ -212,9 +236,11 @@ class StreamExecutorInterface { virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, const void *host_src, uint64 size) = 0; virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &host_src, + const DeviceMemoryBase &gpu_src, uint64 size) = 0; virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0; + virtual bool HostCallback(Stream *stream, + std::function<port::Status()> callback); virtual port::Status AllocateEvent(Event *event) = 0; virtual port::Status DeallocateEvent(Event *event) = 0; virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; @@ -246,7 +272,12 @@ class StreamExecutorInterface { // null, however, both of them cannot be null at the same time. To use // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol // is found. - virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) { + // + // If ModuleHandle is set then we search for `symbol_name` only within the + // module corresponding to `module_handle`. Otherwise all loaded modules are + // searched. + virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes) { return false; } @@ -324,13 +355,14 @@ class StreamExecutorInterface { virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; - // Returns the CUDA context associated with this StreamExecutor platform - // implementation. + // Returns the CUDA or ROCm context associated with this StreamExecutor + // platform implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaContextHack() { return nullptr; } + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or ROCm + // as a platform. 
+ virtual void *GpuContextHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 000795ff00..9515d8e62a 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -220,6 +220,15 @@ void StreamExecutor::UnloadKernel(const KernelBase *kernel) { implementation_->UnloadKernel(kernel); } +bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + return implementation_->LoadModule(spec, module_handle); +} + +bool StreamExecutor::UnloadModule(ModuleHandle module_handle) { + return implementation_->UnloadModule(module_handle); +} + void StreamExecutor::Deallocate(DeviceMemoryBase *mem) { VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque() << ") mem->size()=" << mem->size() << StackTraceIfVLOG10(); @@ -459,9 +468,34 @@ void *StreamExecutor::Allocate(uint64 size) { return buf; } -bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem, +port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol( + const string &symbol_name, ModuleHandle module_handle) { + // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to + // be nullptr/0 for consistency with DeviceMemory semantics. + void *opaque = nullptr; + size_t bytes = 0; + if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) { + return DeviceMemoryBase(opaque, bytes); + } + + if (static_cast<bool>(module_handle)) { + return port::Status( + port::error::NOT_FOUND, + port::StrCat("Check if module containing symbol ", symbol_name, + " is loaded (module_handle = ", + reinterpret_cast<uintptr_t>(module_handle.id()), ")")); + } else { + return port::Status( + port::error::NOT_FOUND, + port::StrCat("Check if kernel using the symbol is loaded: ", + symbol_name)); + } +} + +bool StreamExecutor::GetSymbol(const string &symbol_name, + ModuleHandle module_handle, void **mem, size_t *bytes) { - return implementation_->GetSymbol(symbol_name, mem, bytes); + return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes); } void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) { @@ -665,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream, return implementation_->HostCallback(stream, std::move(callback)); } +bool StreamExecutor::HostCallback(Stream *stream, + std::function<port::Status()> callback) { + return implementation_->HostCallback(stream, std::move(callback)); +} + port::Status StreamExecutor::AllocateEvent(Event *event) { return implementation_->AllocateEvent(event); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index ad80a1ba25..437f298616 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -106,6 +106,16 @@ class StreamExecutor { // Releases any state associated with the previously loaded kernel. void UnloadKernel(const KernelBase *kernel); + // Loads a module for the platform this StreamExecutor is acting upon. + // + // `spec` describes the module to be loaded. On success writes the handle for + // the loaded module to `module_handle` and returns true. Else returns false. + bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle); + + // Unloads the module with handle `module_handle`. 
+ bool UnloadModule(ModuleHandle module_handle); + // Synchronously allocates an array on the device of type T with element_count // elements. template <typename T> @@ -169,8 +179,16 @@ class StreamExecutor { // type of symbol and T match. // - Note: symbol_name should include its namespace as well. For example, // pass "nms0::symbol" if referring to nms0::symbol. + // + // If `module_handle` is set then searches only within the module + // corresponding to `module_handle`. template <typename T> - port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name); + port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name, + ModuleHandle module_handle = {}); + + // An untyped version of GetSymbol. + port::StatusOr<DeviceMemoryBase> GetUntypedSymbol( + const string &symbol_name, ModuleHandle module_handle = {}); // Deallocate the DeviceMemory previously allocated via this interface. // Deallocation of a nullptr-representative value is permitted. @@ -507,7 +525,8 @@ class StreamExecutor { // Finds and retrieves device memory for the symbol on the underlying // platform. - bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes); + bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes); // Entrains a memcpy operation onto stream, with a host destination location // host_dst and a device memory source, with target size size. @@ -530,6 +549,11 @@ class StreamExecutor { // See Stream::ThenDoHostCallback for full details. bool HostCallback(Stream *stream, std::function<void()> callback); + // Entrains on a stream a user-specified function to be run on the host. + // See Stream::ThenDoHostCallback for full details. + // This is the preferred form for a callback that may return an error. + bool HostCallback(Stream *stream, std::function<port::Status()> callback); + // Performs platform-specific allocation and initialization of an event. port::Status AllocateEvent(Event *event); @@ -678,6 +702,41 @@ class StreamExecutor { SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor); }; +// A wrapper around ModuleHandle that uses RAII to manage its lifetime. +class ScopedModuleHandle { + public: + explicit ScopedModuleHandle(StreamExecutor *executor, + ModuleHandle module_handle) + : executor_(executor), module_handle_(module_handle) {} + + ScopedModuleHandle(ScopedModuleHandle &&other) { + executor_ = other.executor_; + module_handle_ = other.module_handle_; + other.executor_ = nullptr; + other.module_handle_ = ModuleHandle(); + } + + ScopedModuleHandle &operator=(ScopedModuleHandle &&other) { + executor_ = other.executor_; + module_handle_ = other.module_handle_; + other.executor_ = nullptr; + other.module_handle_ = ModuleHandle(); + return *this; + } + + ~ScopedModuleHandle() { + if (static_cast<bool>(module_handle_)) { + CHECK(executor_->UnloadModule(module_handle_)); + } + } + + private: + StreamExecutor *executor_; + ModuleHandle module_handle_; + + TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle); +}; + //////////// // Inlines @@ -690,19 +749,13 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) { template <typename T> inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol( - const string &symbol_name) { - // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to - // be nullptr/0 for consistency with DeviceMemory semantics. 
- void *opaque = nullptr; - size_t bytes = 0; - if (GetSymbol(symbol_name, &opaque, &bytes)) { - CHECK_EQ(bytes % sizeof(T), 0); - return DeviceMemory<T>::MakeFromByteSize(opaque, bytes); + const string &symbol_name, ModuleHandle module_handle) { + port::StatusOr<DeviceMemoryBase> untyped_symbol = + GetUntypedSymbol(symbol_name, module_handle); + if (!untyped_symbol.ok()) { + return untyped_symbol.status(); } - return port::Status( - port::error::NOT_FOUND, - port::StrCat("Check if kernel using the symbol is loaded: ", - symbol_name)); + return DeviceMemory<T>(untyped_symbol.ValueOrDie()); } template <typename ElemT> diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc new file mode 100644 index 0000000000..cfc051fd09 --- /dev/null +++ b/tensorflow/stream_executor/stream_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/stream_executor.h" + +#include "tensorflow/core/platform/test.h" + +namespace stream_executor { +namespace { + +class StreamTest : public ::testing::Test { + protected: + std::unique_ptr<StreamExecutor> NewStreamExecutor() { + Platform* platform = + MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamTest, NoInitNotOk) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + EXPECT_FALSE(stream.ok()); +} + +TEST_F(StreamTest, InitOk) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); +} + +TEST_F(StreamTest, OneSubStream) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return a sub-stream. Sub-streams are always initialized. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Get and return another sub-stream. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + stream.ReturnSubStream(sub_stream1); + + // The underlying sub-streams should be the same, since sub_stream1 + // was returned before we tried to get sub_stream2. + EXPECT_EQ(sub_stream1, sub_stream2); +} + +TEST_F(StreamTest, TwoSubStreams) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get two sub-streams. 
+ Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying sub-streams should be different, since neither + // sub-stream has been returned. + EXPECT_NE(sub_stream1, sub_stream2); + + // Return sub_stream1 and get sub_stream3, which should be the same. + stream.ReturnSubStream(sub_stream1); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream1, sub_stream3); + EXPECT_NE(sub_stream2, sub_stream3); + + // Return sub_stream2 and get sub_stream4, which should be the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream4 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream4->ok()); + EXPECT_EQ(sub_stream2, sub_stream4); + EXPECT_NE(sub_stream3, sub_stream4); +} + +TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Return sub_stream1 and get sub_stream2. + stream.ReturnSubStream(sub_stream1); + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying sub_streams should be different. They would have been the + // same, but since we forced an error on sub_stream1, it will not be + // re-used. Sadly we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new Stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream2, sub_stream3); +} + +TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { + std::unique_ptr<StreamExecutor> executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + // + // It is a bit weird to use sub_stream1 after it has already been returned. By + // doing this, we're simulating an asynchronous error that occurs during + // execution of the sub_stream, that occurs after the sub_stream is returned. + // + // E.g. the following is a common pattern of usage, where the execution of the + // operations enqueued onto the sub streams may occur after the streams have + // already been returned. + // + // void EnqueueOnSubStreams(Stream* stream) { + // Stream* sub_stream1 = stream.GetOrCreateSubStream(); + // Stream* sub_stream2 = stream.GetOrCreateSubStream(); + // // ... 
enqueue some operations on the sub streams ... + // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); + // stream.ReturnSubStream(sub_stream1); + // stream.ReturnSubStream(sub_stream2); + // } + // + // Stream* main_stream = ...; + // EnqueueOnSubStreams(main_stream); + // main_stream.BlockHostUntilDone(); + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone; + // GetOrCreateSubStream can still return a sub-stream that has not encountered + // an error yet, but will encounter one in the future, based on previously + // enqueued operations. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Get and return sub_stream2. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying streams should be different. They would have been the same, + // but since we forced an error on sub_stream1, it will not be re-used. Sadly + // we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream2, sub_stream3); +} + +} // namespace +} // namespace stream_executor |
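The new module-loading pieces (MultiModuleLoaderSpec, ModuleHandle, LoadModule/UnloadModule, ScopedModuleHandle, and the module-scoped GetSymbol) compose as follows. A hedged sketch, assuming a caller that already has a StreamExecutor and a non-empty PTX string; the function name is illustrative, and the symbol name reuses the nms0::symbol example from the GetSymbol comment:

#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace stream_executor {

// Illustrative sketch: loads a PTX module, scopes its lifetime with
// ScopedModuleHandle, and looks up a device global the module is assumed to
// define, searching only within that module.
bool LoadModuleAndFindSymbolSketch(StreamExecutor* executor,
                                   const char* ptx_text) {
  MultiModuleLoaderSpec spec;
  spec.AddCudaPtxInMemory(ptx_text);  // ptx_text is assumed to be non-empty PTX.

  ModuleHandle handle;
  if (!executor->LoadModule(spec, &handle)) {
    return false;
  }
  // Unloads the module automatically when scoped_handle is destroyed.
  ScopedModuleHandle scoped_handle(executor, handle);

  // Passing the handle restricts the search to the module loaded above.
  port::StatusOr<DeviceMemory<int>> symbol =
      executor->GetSymbol<int>("nms0::symbol", handle);
  return symbol.ok();
}

}  // namespace stream_executor

GetUntypedSymbol offers the same module-scoped lookup as a DeviceMemoryBase when no element type is known at compile time.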