diff options
author | Dan Ringwalt <ringwalt@google.com> | 2016-11-04 09:33:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-04 10:42:50 -0700 |
commit | e706b3b1833d838f48b515da2705aa31bcf0dc84 (patch) | |
tree | 3d7e93e93faee1db6e83ad176061e71fde15d6e4 | |
parent | 7103c524b9ed9dc0f3a0e194cdfeb92beb42b5d9 (diff) |
Use aligned vector stores for random ops in CUDA.
Store the output generated by a single call to PhiloxRandom in a single vector, with partial specializations for float4, int4, double2, and long2. The unspecialized SampleCopier template does not use a vector, and loops over the individual outputs.
This speeds up RandomUniform (probably the most IO-bound random op) by about 33%. The other random ops (which are probably compute-bound) are slightly improved.
Benchmark Time(ns) CPU(ns) Iterations
----------------------------------------------------------
BM_gpu_RandomUniform/1M 226000 376443 1846 2.594G items/s
BM_gpu_RandomUniform/2M 407823 640758 1000 3.048G items/s
BM_gpu_RandomUniform/8M 1445044 1801758 401 4.336G items/s
BM_gpu_RandomNormal/1M 455218 725107 917 1.347G items/s
BM_gpu_RandomNormal/2M 816297 1131202 610 1.727G items/s
BM_gpu_RandomNormal/8M 3015506 3377695 213 2.313G items/s
BM_gpu_TruncatedNormal/1M 1015598 1332691 515 750.361M items/s
BM_gpu_TruncatedNormal/2M 1907191 2260752 312 884.661M items/s
BM_gpu_TruncatedNormal/8M 7292608 7789287 100 1.003G items/s
Change: 138205581
-rw-r--r-- | tensorflow/core/kernels/random_op_gpu.cu.cc | 106 |
1 file changed, 96 insertions(+), 10 deletions(-)
diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc
index 0f1e637aae..5f7d9b7dd6 100644
--- a/tensorflow/core/kernels/random_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/random_op_gpu.cu.cc
@@ -39,6 +39,89 @@
typedef Eigen::GpuDevice GPUDevice;

template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

// Copies one group of freshly generated samples into the output buffer.
// The generic version stores each element individually and imposes no
// alignment requirement beyond that of T; the specializations below emit a
// single 128-bit vector store (float4 / int4 / double2 / longlong2) instead.
template <typename T, int ElementCount>
class SampleCopier {
 public:
  inline __device__ void operator()(
      T* buf, const tensorflow::random::Array<T, ElementCount>& array) const {
#pragma unroll
    for (int idx = 0; idx < ElementCount; ++idx) {
      buf[idx] = array[idx];
    }
  }
};

template <>
class SampleCopier<float, 4> {
 public:
  // Copies the four samples in `array` to `buf` with one 128-bit store.
  // Precondition: buf is 128-bit aligned — true for tensor data and for all
  // offsets that are a multiple of the vector size (the vector is 128 bits).
  inline __device__ void operator()(
      float* buf, const tensorflow::random::Array<float, 4>& array) const {
    // NOTE(ringwalt): &array[0] cannot simply be cast to a float4*: the array
    // elements carry 32-bit alignment, not the 128-bit alignment float4
    // requires. Building the vector element by element costs nothing.
    const float4 vec = {array[0], array[1], array[2], array[3]};
    float4* dst = reinterpret_cast<float4*>(buf);
    *dst = vec;
  }
};

template <>
class SampleCopier<int32, 4> {
 public:
  // Copies the four samples in `array` to `buf` with one 128-bit store.
  // Precondition: buf is 128-bit aligned — true for tensor data and for all
  // offsets that are a multiple of the vector size (the vector is 128 bits).
  inline __device__ void operator()(
      int32* buf, const tensorflow::random::Array<int32, 4>& array) const {
    const int4 vec = {array[0], array[1], array[2], array[3]};
    int4* dst = reinterpret_cast<int4*>(buf);
    *dst = vec;
  }
};

template <>
class SampleCopier<double, 2> {
 public:
  // Copies the two samples in `array` to `buf` with one 128-bit store.
  // Precondition: buf is 128-bit aligned — true for tensor data and for all
  // offsets that are a multiple of the vector size (the vector is 128 bits).
  inline __device__ void operator()(
      double* buf, const tensorflow::random::Array<double, 2>& array) const {
    const double2 vec = {array[0], array[1]};
    double2* dst = reinterpret_cast<double2*>(buf);
    *dst = vec;
  }
};

template <>
class SampleCopier<int64, 2> {
 public:
  // Copies the two samples in `array` to `buf` with one 128-bit store.
  // Precondition: buf is 128-bit aligned — true for tensor data and for all
  // offsets that are a multiple of the vector size (the vector is 128 bits).
  inline __device__ void operator()(
      int64* buf, const tensorflow::random::Array<int64, 2>& array) const {
    const longlong2 vec = {array[0], array[1]};
    longlong2* dst = reinterpret_cast<longlong2*>(buf);
    *dst = vec;
  }
};

// A cuda kernel to fill the data with random numbers from the specified
// distribution. Each output takes a fixed number of samples.
template <class Distribution> @@ -53,20 +136,23 @@ struct FillPhiloxRandomKernel<Distribution, false> {
// NOTE(review): this is a diff hunk — '-' prefixes mark removed lines, '+'
// prefixes mark added lines. The enclosing operator() signature (and the
// declarations of data, size, dist, gen, thread_id, total_thread_count) lie
// outside this excerpt.
int32 offset = thread_id * kGroupSize; gen.Skip(thread_id);
// Removed slow path: per-element stores with a bounds check inside the loop.
- while (offset < size) { - typename Distribution::ResultType samples = dist(&gen); - - for (int i = 0; i < kGroupSize; ++i) { - if (offset >= size) { - return; - } - data[offset] = samples[i]; - ++offset; - }
// Added fast path: while a full group of kGroupSize outputs fits, write it
// with one SampleCopier call (a vectorized store for the specialized types),
// hoisting the bounds check out of the inner loop.
+ const SampleCopier<T, kGroupSize> copier; + while (offset + kGroupSize <= size) { + const typename Distribution::ResultType samples = dist(&gen); + copier(&data[offset], samples);
// Grid-stride-style advance: skip past the groups owned by the other threads.
offset += (total_thread_count - 1) * kGroupSize; gen.Skip(total_thread_count - 1); }
// Tail: fewer than kGroupSize outputs remain, so fall back to per-element
// stores with the bounds check. NOTE(review): dist(&gen) is drawn even when
// offset >= size — the extra sample is discarded but the generator still
// advances; presumably intentional (matches the fixed samples-per-thread
// accounting) — TODO confirm.
+ + typename Distribution::ResultType samples = dist(&gen); + for (int i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + }
} }; |