diff options
author | Matt Conley <mconley@nvidia.com> | 2018-08-28 18:55:51 -0700 |
---|---|---|
committer | Matt Conley <mconley@nvidia.com> | 2018-08-28 18:55:51 -0700 |
commit | e93a9f9ccfd9c7a2419bf3fc1d7866765bbcfce3 (patch) | |
tree | 64911ea09beae2cd57365da73ca03c3d805665db /tensorflow/stream_executor/cuda | |
parent | 2e7352e57c541908cd700bb0fe53a04b456392c9 (diff) |
Update GPU occupancy checking to utilize CUDA's occupancy calculator functions
-Replace references to the UnqueryableDeviceParams struct with calls to CUDA's built-in occupancy calculation functions
-Update calls to the occupancy checking functions with the new changes
-Changes should provide more long-term reliability and will remove the need to manually update hardcoded data values for new GPU architectures
Diffstat (limited to 'tensorflow/stream_executor/cuda')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 192 |
1 files changed, 19 insertions, 173 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index e30f50ea2a..39b0696c93 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -467,33 +467,26 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, return; } + int block_size = thread_dims.x * thread_dims.y * thread_dims.z; + const DeviceDescription &device_description = kernel.parent()->GetDeviceDescription(); - uint64 blocks_per_sm = CalculateOccupancy( - device_description, regs_per_thread, smem_per_block, thread_dims); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; + const CUDAKernel* cuda_kernel = AsCUDAKernel(&kernel); + CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue(); - // To increase occupancy, there must be a sufficient number of blocks - // available to spread across the sm's at this new improved occupancy level. - int multiprocessor_count = device_description.core_count(); - int block_count = block_dims.x * block_dims.y * block_dims.z; - int available_blocks_per_sm = - port::MathUtil::CeilOfRatio(block_count, multiprocessor_count); - if (available_blocks_per_sm <= static_cast<int64>(blocks_per_sm)) { - VLOG(2) << "Occupancy is limited by number of blocks available per sm."; - return; - } + int blocks_per_sm = CalculateOccupancy(device_description, regs_per_thread, + smem_per_block, thread_dims, cufunc); + VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - uint64 improved_regs_per_thread = CalculateRegisterLimitForTargetOccupancy( - device_description, smem_per_block, thread_dims, blocks_per_sm + 1); - if (improved_regs_per_thread != 0) { - VLOG(2) << "Reducing register usage from " << regs_per_thread - << " to " << improved_regs_per_thread - << " could increase resident blocks per SM by one."; - } else { - VLOG(2) << "Resident blocks per SM cannot be increased by reducing " - "register usage."; + int suggested_threads = + CompareOccupancy(&blocks_per_sm, device_description, regs_per_thread, + smem_per_block, thread_dims, cufunc); + if (suggested_threads != 0) { + VLOG(2) << "The cuda occupancy calculator reccommends using " + << suggested_threads + << " threads per block to acheive an occupancy of " << blocks_per_sm + << " blocks per SM."; } } @@ -980,144 +973,6 @@ static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { #endif } -// Set of compute capability specific device parameters that cannot be -// queried from the driver API. These values instead are baked into a -// lookup table indexed by compute capability version. -struct UnqueryableDeviceParams { - int cc_major; - int cc_minor; - uint64 blocks_per_core_limit; - uint64 registers_per_core_limit; - uint64 registers_per_thread_limit; - uint64 warp_alloc_granularity; - uint64 register_alloc_granularity; - uint64 shared_memory_alloc_granularity; -}; - -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities -// https://developer.download.nvidia.com/compute/cuda/CUDA_Occupancy_calculator.xls -static const UnqueryableDeviceParams kAllUnqueryableDeviceParams[] = { - { - 2, 0, // compute capability (2.0) - 8, // blocks_per_core_limit - 32 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 2, // warp_alloc_granularity - 64, // register_alloc_granularity - 128, // shared_memory_alloc_granularity - }, - { - 2, 1, // compute capability (2.1) - 8, // blocks_per_core_limit - 32 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 2, // warp_alloc_granularity - 64, // register_alloc_granularity - 128, // shared_memory_alloc_granularity - }, - { - 3, 0, // compute capability (3.0) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 2, // compute capability (3.2) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 5, // compute capability (3.5) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 7, // compute capability (3.7) - 16, // blocks_per_core_limit - 128 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 0, // compute capability (5.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 2, // compute capability (5.2) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 3, // compute capability (5.3) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 0, // compute capability (6.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 2, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 1, // compute capability (6.1) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 2, // compute capability (6.2) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - // TODO(jlebar): Confirm the alloc granularity values for sm_70. These are - // not published in the spreadsheet linked above. Currently we guess that - // they're the same as sm_60. - { - 7, 0, // compute capability (7.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 2, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, -}; DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { internal::DeviceDescriptionBuilder builder; @@ -1193,19 +1048,6 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { builder.set_name(device_name); } - for (size_t i = 0; i < TF_ARRAYSIZE(kAllUnqueryableDeviceParams); i++) { - const auto ¶ms = kAllUnqueryableDeviceParams[i]; - if (params.cc_major == cc_major_ && params.cc_minor == cc_minor_) { - builder.set_blocks_per_core_limit(params.blocks_per_core_limit); - builder.set_registers_per_core_limit(params.registers_per_core_limit); - builder.set_registers_per_thread_limit(params.registers_per_thread_limit); - builder.set_warp_alloc_granularity(params.warp_alloc_granularity); - builder.set_register_alloc_granularity(params.register_alloc_granularity); - builder.set_shared_memory_alloc_granularity( - params.shared_memory_alloc_granularity); - } - } - builder.set_platform_version( port::StrCat("Compute Capability ", cc_major_, ".", cc_minor_)); @@ -1227,6 +1069,10 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { CUDADriver::GetMaxRegistersPerBlock(device_).ValueOrDie()); builder.set_threads_per_warp( CUDADriver::GetThreadsPerWarp(device_).ValueOrDie()); + builder.set_registers_per_core_limit( + CUDADriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, device_) + .ValueOrDie()); auto built = builder.Build(); return built.release(); |