aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/random_op.cc')
-rw-r--r--tensorflow/core/kernels/random_op.cc34
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);