diff options
Diffstat (limited to 'tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index 42d4744069..8b85bd4ebe 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -29,7 +29,14 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/cuda_kernel_helper.h" +#ifdef COMPILER_MSVC +// msvc does not support unroll. One could try the loop pragma but we need to +// take a closer look if this generates better code in this case. For now let +// the compiler take care of of it. +#define UNROLL +#else #define UNROLL _Pragma("unroll") +#endif namespace tensorflow { @@ -99,6 +106,7 @@ __global__ void __launch_bounds__(1024) Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) / (normMin + sqrtFactor); const T diff = normMax - normMin; + const T two = T(2.0); // Validate the normalized min and max, because the originals may have been // flipped already. @@ -124,7 +132,7 @@ __global__ void __launch_bounds__(1024) z[i] = rand[i] * diff + normMin; } UNROLL for (int i = 0; i < kDistSize; i++) { - g[i] = (plusFactor - z[i] * z[i]) / 2.0; + g[i] = (plusFactor - z[i] * z[i]) / two; } const auto u = dist(&gen); @@ -161,7 +169,7 @@ __global__ void __launch_bounds__(1024) UNROLL for (int i = 0; i < kDistSize; i += 2) { const T z = -Eigen::numext::log(rand[i]) / alpha + normMin; const T x = normMin < alpha ? alpha - z : normMin - alpha; - const T g = Eigen::numext::exp(-x * x / 2.0); + const T g = Eigen::numext::exp(-x * x / two); const T u = rand[i + 1]; if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) { data[offset] = z * stddev + mean; |