diff options
author | 2017-12-21 00:33:01 -0800 | |
---|---|---|
committer | 2017-12-21 00:36:24 -0800 | |
commit | 1ad680d3a36e9742229748d5c5d3bddb1d0a2578 (patch) | |
tree | b33384c28cecee2be360426e5bd59c9c27c846cc /tensorflow/core/util/cuda_kernel_helper_test.cu.cc | |
parent | 127f1b21dccf4eb46b6cf80657dfa340ed1c1ede (diff) |
Roll CL 177989542 forward with fix: Wrappers for CUDA 9 warp-synchronous intrinsics.
PiperOrigin-RevId: 179782067
Diffstat (limited to 'tensorflow/core/util/cuda_kernel_helper_test.cu.cc')
-rw-r--r-- | tensorflow/core/util/cuda_kernel_helper_test.cu.cc | 60 |
1 files changed, 54 insertions, 6 deletions
diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc index 6991554eff..bd4c356ea0 100644 --- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc @@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } @@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { if (z < 0) { // z might overflow when testing extreme case break; } @@ -87,6 +87,44 @@ __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { } } +__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) { + unsigned lane_id = CudaLaneId(); + for (int width = warpSize; width > 1; width /= 2) { + auto check_result = [&](const char* op_name, int param, unsigned actual, + unsigned expected) { + if (actual != expected) { + printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n", + op_name, param, width, lane_id, actual, expected); + CudaAtomicAdd(failure_count, 1); + } + }; + for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) { + unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width); + unsigned expect_lane = + CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width); + check_result("Shuffle", src_lane, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleUp", delta, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleDown", delta, actual_lane, expect_lane); + } + for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) { + unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width); + unsigned expect_lane = + CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width); + check_result("ShuffleXor", lane_lane, actual_lane, expect_lane); + } + } +} + } // namespace class CudaLaunchConfigTest : public ::testing::Test { @@ -94,7 +132,7 @@ class CudaLaunchConfigTest : public ::testing::Test { const int bufsize = 1024; int* outbuf = nullptr; Eigen::CudaStreamDevice stream; - GPUDevice d = GPUDevice(&stream); + Eigen::GpuDevice d = Eigen::GpuDevice(&stream); virtual void SetUp() { cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize); @@ -229,6 +267,16 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { #undef TEST_LAUNCH_PARAMETER } +TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) { + unsigned* failure_count; + ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess); + *failure_count = 0; + CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + ASSERT_EQ(*failure_count, 0); + cudaFree(failure_count); +} + } // namespace tensorflow #endif // GOOGLE_CUDA |