From e93a9f9ccfd9c7a2419bf3fc1d7866765bbcfce3 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 28 Aug 2018 18:55:51 -0700 Subject: Update GPU occupancy checking to utilize CUDA's occupancy calculator functions -Replace references to the UnqueryableDeviceParams struct with calls to CUDA's built-in occupancy calculation functions -Update calls to the occupancy checking functions with the new changes -Changes should provide more long-term reliability and will remove the need to manually update hardcoded data values for new GPU architectures --- .../stream_executor/cuda/cuda_gpu_executor.cc | 192 ++------------------- tensorflow/stream_executor/device_description.cc | 98 +++-------- tensorflow/stream_executor/device_description.h | 73 ++------ 3 files changed, 59 insertions(+), 304 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index e30f50ea2a..39b0696c93 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -467,33 +467,26 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, return; } + int block_size = thread_dims.x * thread_dims.y * thread_dims.z; + const DeviceDescription &device_description = kernel.parent()->GetDeviceDescription(); - uint64 blocks_per_sm = CalculateOccupancy( - device_description, regs_per_thread, smem_per_block, thread_dims); - VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; + const CUDAKernel* cuda_kernel = AsCUDAKernel(&kernel); + CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue(); - // To increase occupancy, there must be a sufficient number of blocks - // available to spread across the sm's at this new improved occupancy level. - int multiprocessor_count = device_description.core_count(); - int block_count = block_dims.x * block_dims.y * block_dims.z; - int available_blocks_per_sm = - port::MathUtil::CeilOfRatio(block_count, multiprocessor_count); - if (available_blocks_per_sm <= static_cast(blocks_per_sm)) { - VLOG(2) << "Occupancy is limited by number of blocks available per sm."; - return; - } + int blocks_per_sm = CalculateOccupancy(device_description, regs_per_thread, + smem_per_block, thread_dims, cufunc); + VLOG(2) << "Resident blocks per SM is " << blocks_per_sm; - uint64 improved_regs_per_thread = CalculateRegisterLimitForTargetOccupancy( - device_description, smem_per_block, thread_dims, blocks_per_sm + 1); - if (improved_regs_per_thread != 0) { - VLOG(2) << "Reducing register usage from " << regs_per_thread - << " to " << improved_regs_per_thread - << " could increase resident blocks per SM by one."; - } else { - VLOG(2) << "Resident blocks per SM cannot be increased by reducing " - "register usage."; + int suggested_threads = + CompareOccupancy(&blocks_per_sm, device_description, regs_per_thread, + smem_per_block, thread_dims, cufunc); + if (suggested_threads != 0) { + VLOG(2) << "The cuda occupancy calculator reccommends using " + << suggested_threads + << " threads per block to acheive an occupancy of " << blocks_per_sm + << " blocks per SM."; } } @@ -980,144 +973,6 @@ static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { #endif } -// Set of compute capability specific device parameters that cannot be -// queried from the driver API. These values instead are baked into a -// lookup table indexed by compute capability version. -struct UnqueryableDeviceParams { - int cc_major; - int cc_minor; - uint64 blocks_per_core_limit; - uint64 registers_per_core_limit; - uint64 registers_per_thread_limit; - uint64 warp_alloc_granularity; - uint64 register_alloc_granularity; - uint64 shared_memory_alloc_granularity; -}; - -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities -// https://developer.download.nvidia.com/compute/cuda/CUDA_Occupancy_calculator.xls -static const UnqueryableDeviceParams kAllUnqueryableDeviceParams[] = { - { - 2, 0, // compute capability (2.0) - 8, // blocks_per_core_limit - 32 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 2, // warp_alloc_granularity - 64, // register_alloc_granularity - 128, // shared_memory_alloc_granularity - }, - { - 2, 1, // compute capability (2.1) - 8, // blocks_per_core_limit - 32 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 2, // warp_alloc_granularity - 64, // register_alloc_granularity - 128, // shared_memory_alloc_granularity - }, - { - 3, 0, // compute capability (3.0) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 63, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 2, // compute capability (3.2) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 5, // compute capability (3.5) - 16, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 3, 7, // compute capability (3.7) - 16, // blocks_per_core_limit - 128 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 0, // compute capability (5.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 2, // compute capability (5.2) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 5, 3, // compute capability (5.3) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 0, // compute capability (6.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 2, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 1, // compute capability (6.1) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - { - 6, 2, // compute capability (6.2) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 4, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, - // TODO(jlebar): Confirm the alloc granularity values for sm_70. These are - // not published in the spreadsheet linked above. Currently we guess that - // they're the same as sm_60. - { - 7, 0, // compute capability (7.0) - 32, // blocks_per_core_limit - 64 * 1024, // registers_per_core_limit - 255, // registers_per_thread_limit - 2, // warp_alloc_granularity - 256, // register_alloc_granularity - 256, // shared_memory_alloc_granularity - }, -}; DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { internal::DeviceDescriptionBuilder builder; @@ -1193,19 +1048,6 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { builder.set_name(device_name); } - for (size_t i = 0; i < TF_ARRAYSIZE(kAllUnqueryableDeviceParams); i++) { - const auto ¶ms = kAllUnqueryableDeviceParams[i]; - if (params.cc_major == cc_major_ && params.cc_minor == cc_minor_) { - builder.set_blocks_per_core_limit(params.blocks_per_core_limit); - builder.set_registers_per_core_limit(params.registers_per_core_limit); - builder.set_registers_per_thread_limit(params.registers_per_thread_limit); - builder.set_warp_alloc_granularity(params.warp_alloc_granularity); - builder.set_register_alloc_granularity(params.register_alloc_granularity); - builder.set_shared_memory_alloc_granularity( - params.shared_memory_alloc_granularity); - } - } - builder.set_platform_version( port::StrCat("Compute Capability ", cc_major_, ".", cc_minor_)); @@ -1227,6 +1069,10 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { CUDADriver::GetMaxRegistersPerBlock(device_).ValueOrDie()); builder.set_threads_per_warp( CUDADriver::GetThreadsPerWarp(device_).ValueOrDie()); + builder.set_registers_per_core_limit( + CUDADriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, device_) + .ValueOrDie()); auto built = builder.Build(); return built.release(); diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc index 8ca0677f8a..df52ce6cce 100644 --- a/tensorflow/stream_executor/device_description.cc +++ b/tensorflow/stream_executor/device_description.cc @@ -37,16 +37,11 @@ DeviceDescription::DeviceDescription() kUninitializedUint64), block_dim_limit_(kUninitializedUint64, kUninitializedUint64, kUninitializedUint64), - blocks_per_core_limit_(kUninitializedUint64), threads_per_core_limit_(kUninitializedUint64), threads_per_block_limit_(kUninitializedUint64), threads_per_warp_(kUninitializedUint64), registers_per_core_limit_(kUninitializedUint64), registers_per_block_limit_(kUninitializedUint64), - registers_per_thread_limit_(kUninitializedUint64), - warp_alloc_granularity_(1), - register_alloc_granularity_(1), - shared_memory_alloc_granularity_(1), device_address_bits_(kUninitializedUint64), device_memory_size_(kUninitializedUint64), memory_bandwidth_(kUninitializedUint64), @@ -162,75 +157,36 @@ static uint64 RoundDown(uint64 value, uint64 n) { return port::MathUtil::FloorOfRatio(value, n) * n; } -uint64 CalculateOccupancy(const DeviceDescription &device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim &thread_dims) { - // Don't try to compute occupancy if necessary values are not initialized. - uint64 required_fields[] = { device_description.registers_per_thread_limit(), - device_description.threads_per_warp(), - device_description.warp_alloc_granularity(), - device_description.register_alloc_granularity(), - device_description.registers_per_block_limit(), - device_description.shared_memory_per_core(), - device_description.blocks_per_core_limit() }; - for (auto value : required_fields) { - if (value == kUninitializedUint64) { - return 0; - } - } - - if (registers_per_thread > device_description.registers_per_thread_limit()) { - return 0; - } - - uint64 warps_per_block = - port::MathUtil::CeilOfRatio(thread_dims.x * thread_dims.y * thread_dims.z, - device_description.threads_per_warp()); - - // Warp resources are allocated at a particular granularity. This value is - // the effective number of warps for resource allocation purposes. - uint64 alloc_warps_per_block = - RoundUp(warps_per_block, device_description.warp_alloc_granularity()); - - uint64 alloc_regs_per_warp = - RoundUp(device_description.threads_per_warp() * registers_per_thread, - device_description.register_alloc_granularity()); - uint64 regs_per_block = alloc_warps_per_block * alloc_regs_per_warp; - uint64 reg_limit = - device_description.registers_per_block_limit() / regs_per_block; - - uint64 alloc_smem_per_block = RoundUp( - shared_memory_per_block, - device_description.shared_memory_alloc_granularity()); - uint64 smem_limit = alloc_smem_per_block > 0 ? - device_description.shared_memory_per_core() / alloc_smem_per_block : - device_description.blocks_per_core_limit(); - - uint64 thread_limit = device_description.threads_per_core_limit() - / (warps_per_block * device_description.threads_per_warp()); - - return std::min({ device_description.blocks_per_core_limit(), - reg_limit, smem_limit, thread_limit }); +int CalculateOccupancy(const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func) { + int suggested_blocks = 0; + int suggested_threads = 0; + CUresult err = + cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, + func, NULL, shared_memory_per_block, 0); + CHECK_EQ(err, CUDA_SUCCESS); + return suggested_blocks; } -uint64 CalculateRegisterLimitForTargetOccupancy( - const DeviceDescription &device_description, uint64 shared_memory_per_block, - const ThreadDim &thread_dims, uint64 target_blocks_per_core) { - // Linear search from maximum number of registers down until the target - // blocks per SM is found. - // TODO(meheff): Compute this using a closed form solution. - int reg_step = device_description.register_alloc_granularity() / - device_description.threads_per_warp(); - for (int r = device_description.registers_per_thread_limit(); r > 0; - r = RoundDown(r - 1, reg_step)) { - uint64 occupancy = CalculateOccupancy( - device_description, r, shared_memory_per_block, thread_dims); - if (occupancy >= target_blocks_per_core) { - return r; - } +int CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func) { + int suggested_blocks = 0; + int suggested_threads = 0; + CUresult err = + cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, + func, NULL, shared_memory_per_block, 0); + CHECK_EQ(err, CUDA_SUCCESS); + if (suggested_blocks > *initial_blocks) { + *initial_blocks = suggested_blocks; + return suggested_threads; + } else { + return 0; } - return 0; } } // namespace stream_executor diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h index 7f99d81ef3..d335b9b875 100644 --- a/tensorflow/stream_executor/device_description.h +++ b/tensorflow/stream_executor/device_description.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/cuda/cuda_driver.h" #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/platform/port.h" @@ -79,10 +80,6 @@ class DeviceDescription { // legitimate kernel launch request. const BlockDim &block_dim_limit() const { return block_dim_limit_; } - // Returns the limit on the number of simultaneously resident blocks - // on a multiprocessor. - uint64 blocks_per_core_limit() const { return blocks_per_core_limit_; } - // Returns the limit on the total number of threads that can be launched in a // single block; i.e. the limit on x * y * z dimensions of a ThreadDim. // This limit affects what constitutes a legitimate kernel launch request. @@ -110,27 +107,6 @@ class DeviceDescription { return registers_per_block_limit_; } - // Returns the limit on the total number of registers that can be - // allocated to a thread. - const uint64 ®isters_per_thread_limit() const { - return registers_per_thread_limit_; - } - - // Returns the granularity at which warps are allocated resources. - const uint64 &warp_alloc_granularity() const { - return warp_alloc_granularity_; - } - - // Returns the granularity at which registers are allocated to warps. - const uint64 ®ister_alloc_granularity() const { - return register_alloc_granularity_; - } - - // Returns the granularity at which shared memory is allocated to warps. - const uint64 &shared_memory_alloc_granularity() const { - return shared_memory_alloc_granularity_; - } - // Returns the number of address bits available to kernel code running on the // platform. This affects things like the maximum allocation size and perhaps // types used in kernel code such as size_t. @@ -200,19 +176,12 @@ class DeviceDescription { ThreadDim thread_dim_limit_; BlockDim block_dim_limit_; - uint64 blocks_per_core_limit_; - uint64 threads_per_core_limit_; uint64 threads_per_block_limit_; uint64 threads_per_warp_; uint64 registers_per_core_limit_; uint64 registers_per_block_limit_; - uint64 registers_per_thread_limit_; - - uint64 warp_alloc_granularity_; - uint64 register_alloc_granularity_; - uint64 shared_memory_alloc_granularity_; uint64 device_address_bits_; uint64 device_memory_size_; @@ -270,10 +239,6 @@ class DeviceDescriptionBuilder { device_description_->block_dim_limit_ = value; } - void set_blocks_per_core_limit(uint64 value) { - device_description_->blocks_per_core_limit_ = value; - } - void set_threads_per_core_limit(uint64 value) { device_description_->threads_per_core_limit_ = value; } @@ -290,19 +255,6 @@ class DeviceDescriptionBuilder { void set_registers_per_block_limit(uint64 value) { device_description_->registers_per_block_limit_ = value; } - void set_registers_per_thread_limit(uint64 value) { - device_description_->registers_per_thread_limit_ = value; - } - - void set_warp_alloc_granularity(uint64 value) { - device_description_->warp_alloc_granularity_ = value; - } - void set_register_alloc_granularity(uint64 value) { - device_description_->register_alloc_granularity_ = value; - } - void set_shared_memory_alloc_granularity(uint64 value) { - device_description_->shared_memory_alloc_granularity_ = value; - } void set_device_address_bits(uint64 value) { device_description_->device_address_bits_ = value; @@ -375,17 +327,18 @@ void CalculateDimensionality(const DeviceDescription &device_description, // Compute and return maximum blocks per core (occupancy) based on the // device description, some kernel characteristics and the number of threads per // block. If unable to compute occupancy, zero is returned. -uint64 CalculateOccupancy(const DeviceDescription &device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim &thread_dims); - -// Compute and return the maximum number of registers per thread which -// achieves the target occupancy. If the target is not possible then -// zero is returned. -uint64 CalculateRegisterLimitForTargetOccupancy( - const DeviceDescription &device_description, uint64 shared_memory_per_block, - const ThreadDim &thread_dims, uint64 target_blocks_per_core); +int CalculateOccupancy(const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func); + +// Compute and return the suggested thread count to acheive ideal occupancy. +// If the provided thread dimensions match this number, zero is returned. +int CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func); } // namespace stream_executor -- cgit v1.2.3 From fa20b59b920233d35bb8da3fbc3c234c369a8291 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 4 Sep 2018 14:20:40 -0700 Subject: Move CUDA-specific occupancy calculation into proper file -Maintain functionality, just move CalculateOccupancy() and CompareOccupancy() methods from device_description to cuda_gpu_executor -Remove CUDA requirement in general class device_description --- .../stream_executor/cuda/cuda_gpu_executor.cc | 37 ++++++++++++++++++++++ .../stream_executor/cuda/cuda_gpu_executor.h | 11 +++++++ tensorflow/stream_executor/device_description.cc | 32 ------------------- tensorflow/stream_executor/device_description.h | 17 ---------- 4 files changed, 48 insertions(+), 49 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 39b0696c93..458c0e3030 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -490,6 +490,43 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, } } +// Compute and return maximum blocks per core (occupancy) based on the +// device description, some kernel characteristics and the number of threads per +// block. If unable to compute occupancy, zero is returned. +int CalculateOccupancy(const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func) { + int suggested_blocks = 0; + int suggested_threads = 0; + CUresult err = + cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, + func, NULL, shared_memory_per_block, 0); + CHECK_EQ(err, CUDA_SUCCESS); + return suggested_blocks; +} + +// Compute and return the suggested thread count to acheive ideal occupancy. +// If the provided thread dimensions match this number, zero is returned. +int CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func) { + int suggested_blocks = 0; + int suggested_threads = 0; + CUresult err = + cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, + func, NULL, shared_memory_per_block, 0); + CHECK_EQ(err, CUDA_SUCCESS); + if (suggested_blocks > *initial_blocks) { + *initial_blocks = suggested_blocks; + return suggested_threads; + } else { + return 0; + } +} + void *CUDAExecutor::Allocate(uint64 size) { return CUDADriver::DeviceAllocate(context_, size); } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 8a954d5461..e8ebbc3220 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -70,6 +70,17 @@ class CUDAExecutor : public internal::StreamExecutorInterface { const BlockDim &block_dims, const KernelBase &k, const KernelArgsArrayBase &args) override; + int CalculateOccupancy(const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func); + + int CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func); + void *Allocate(uint64 size) override; void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes, diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc index df52ce6cce..726c4adf74 100644 --- a/tensorflow/stream_executor/device_description.cc +++ b/tensorflow/stream_executor/device_description.cc @@ -157,36 +157,4 @@ static uint64 RoundDown(uint64 value, uint64 n) { return port::MathUtil::FloorOfRatio(value, n) * n; } -int CalculateOccupancy(const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = - cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, - func, NULL, shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - return suggested_blocks; -} - -int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func) { - int suggested_blocks = 0; - int suggested_threads = 0; - CUresult err = - cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads, - func, NULL, shared_memory_per_block, 0); - CHECK_EQ(err, CUDA_SUCCESS); - if (suggested_blocks > *initial_blocks) { - *initial_blocks = suggested_blocks; - return suggested_threads; - } else { - return 0; - } -} - } // namespace stream_executor diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h index d335b9b875..b15ce31216 100644 --- a/tensorflow/stream_executor/device_description.h +++ b/tensorflow/stream_executor/device_description.h @@ -24,7 +24,6 @@ limitations under the License. #include #include "tensorflow/stream_executor/platform/port.h" -#include "tensorflow/stream_executor/cuda/cuda_driver.h" #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/platform/port.h" @@ -324,22 +323,6 @@ void CalculateDimensionality(const DeviceDescription &device_description, uint64 element_count, uint64 *threads_per_block, uint64 *block_count); -// Compute and return maximum blocks per core (occupancy) based on the -// device description, some kernel characteristics and the number of threads per -// block. If unable to compute occupancy, zero is returned. -int CalculateOccupancy(const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func); - -// Compute and return the suggested thread count to acheive ideal occupancy. -// If the provided thread dimensions match this number, zero is returned. -int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func); - } // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_ -- cgit v1.2.3 From cd6597b8fcd82b51ddb47a297972a1614c2a5d78 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 4 Sep 2018 16:17:40 -0700 Subject: Fixed transition typo --- tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 458c0e3030..a961e9a6c4 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -493,7 +493,7 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, // Compute and return maximum blocks per core (occupancy) based on the // device description, some kernel characteristics and the number of threads per // block. If unable to compute occupancy, zero is returned. -int CalculateOccupancy(const DeviceDescription& device_description, +int CUDAExecutor::CalculateOccupancy(const DeviceDescription& device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, const ThreadDim& thread_dims, CUfunction func) { @@ -508,7 +508,7 @@ int CalculateOccupancy(const DeviceDescription& device_description, // Compute and return the suggested thread count to acheive ideal occupancy. // If the provided thread dimensions match this number, zero is returned. -int CompareOccupancy(int* initial_blocks, +int CUDAExecutor::CompareOccupancy(int* initial_blocks, const DeviceDescription& device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, -- cgit v1.2.3 From 475b7715f16ad0f94fa9986a0eefc1b2cf2044bd Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 4 Sep 2018 16:31:01 -0700 Subject: Recommended typo fix --- tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index a961e9a6c4..ce2f1ce3ae 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -483,7 +483,7 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, CompareOccupancy(&blocks_per_sm, device_description, regs_per_thread, smem_per_block, thread_dims, cufunc); if (suggested_threads != 0) { - VLOG(2) << "The cuda occupancy calculator reccommends using " + VLOG(2) << "The cuda occupancy calculator recommends using " << suggested_threads << " threads per block to acheive an occupancy of " << blocks_per_sm << " blocks per SM."; -- cgit v1.2.3 From d0574f6b25ab01052e093ab92612520a7e4ada8d Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 6 Sep 2018 08:22:37 -0700 Subject: Fixed clang formatting --- tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 17 +++++++++-------- tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 12 ++++++------ 2 files changed, 15 insertions(+), 14 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index ce2f1ce3ae..ef84d01a94 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -493,10 +493,10 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, // Compute and return maximum blocks per core (occupancy) based on the // device description, some kernel characteristics and the number of threads per // block. If unable to compute occupancy, zero is returned. -int CUDAExecutor::CalculateOccupancy(const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func) { +int CUDAExecutor::CalculateOccupancy( + const DeviceDescription& device_description, uint64 registers_per_thread, + uint64 shared_memory_per_block, const ThreadDim& thread_dims, + CUfunction func) { int suggested_blocks = 0; int suggested_threads = 0; CUresult err = @@ -509,10 +509,11 @@ int CUDAExecutor::CalculateOccupancy(const DeviceDescription& device_description // Compute and return the suggested thread count to acheive ideal occupancy. // If the provided thread dimensions match this number, zero is returned. int CUDAExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func) { + const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, + CUfunction func) { int suggested_blocks = 0; int suggested_threads = 0; CUresult err = diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index e8ebbc3220..1481dcc19a 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -71,16 +71,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface { const KernelArgsArrayBase &args) override; int CalculateOccupancy(const DeviceDescription& device_description, + uint64 registers_per_thread, + uint64 shared_memory_per_block, + const ThreadDim& thread_dims, CUfunction func); + + int CompareOccupancy(int* initial_blocks, + const DeviceDescription& device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, const ThreadDim& thread_dims, CUfunction func); - int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, - uint64 registers_per_thread, - uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func); - void *Allocate(uint64 size) override; void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes, -- cgit v1.2.3 From 6a5090b086bc9d665eb9e65f05eb94cdb58baaa2 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 6 Sep 2018 13:09:12 -0700 Subject: Fully fixed clang errors --- tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 12 ++++++------ tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index ef84d01a94..9d5bcc7f77 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -472,7 +472,7 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, const DeviceDescription &device_description = kernel.parent()->GetDeviceDescription(); - const CUDAKernel* cuda_kernel = AsCUDAKernel(&kernel); + const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel); CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue(); int blocks_per_sm = CalculateOccupancy(device_description, regs_per_thread, @@ -494,8 +494,8 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel, // device description, some kernel characteristics and the number of threads per // block. If unable to compute occupancy, zero is returned. int CUDAExecutor::CalculateOccupancy( - const DeviceDescription& device_description, uint64 registers_per_thread, - uint64 shared_memory_per_block, const ThreadDim& thread_dims, + const DeviceDescription &device_description, uint64 registers_per_thread, + uint64 shared_memory_per_block, const ThreadDim &thread_dims, CUfunction func) { int suggested_blocks = 0; int suggested_threads = 0; @@ -508,11 +508,11 @@ int CUDAExecutor::CalculateOccupancy( // Compute and return the suggested thread count to acheive ideal occupancy. // If the provided thread dimensions match this number, zero is returned. -int CUDAExecutor::CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, +int CUDAExecutor::CompareOccupancy(int *initial_blocks, + const DeviceDescription &device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, - const ThreadDim& thread_dims, + const ThreadDim &thread_dims, CUfunction func) { int suggested_blocks = 0; int suggested_threads = 0; diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 1481dcc19a..53b2a29ae7 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -70,16 +70,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface { const BlockDim &block_dims, const KernelBase &k, const KernelArgsArrayBase &args) override; - int CalculateOccupancy(const DeviceDescription& device_description, + int CalculateOccupancy(const DeviceDescription &device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func); + const ThreadDim &thread_dims, CUfunction func); - int CompareOccupancy(int* initial_blocks, - const DeviceDescription& device_description, + int CompareOccupancy(int *initial_blocks, + const DeviceDescription &device_description, uint64 registers_per_thread, uint64 shared_memory_per_block, - const ThreadDim& thread_dims, CUfunction func); + const ThreadDim &thread_dims, CUfunction func); void *Allocate(uint64 size) override; -- cgit v1.2.3