diff options
author | 2018-09-19 09:33:19 -0700 | |
---|---|---|
committer | 2018-10-05 15:02:11 -0700 | |
commit | 1e104d80826fed95f9fad6f07f68e35cae3527b2 (patch) | |
tree | bb934f3e78f55f53fadbdba8408e550f1f61e171 /tensorflow/core/kernels | |
parent | 80c9eec9b2475630f83a596f77a906c8075f8e6c (diff) |
Expand stateless random generators to match their stateful cousins
stateless_random_uniform now take minval+maxval and handles ints,
and stateless_normal/stateless_truncated_normal take mean+stddev.
Additionally, all of the stateless functions now have proper doc
strings.
This is step one of moving stateless random numbers out of contrib.
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/random_op.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/kernels/stateless_random_ops.cc | 155 |
2 files changed, 114 insertions, 75 deletions
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 04a53697c0..3810d817ca 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel { Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ RandomGammaOp<TYPE>) -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<CPUDevice, IntType>); TF_CALL_half(REGISTER); @@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT); random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>); -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<int32>("T") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<int32>("T") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<GPUDevice, IntType>); TF_CALL_half(REGISTER); diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index eab176c7fb..925f5291a6 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +template <typename Device, typename IntType> +class StatelessRandomUniformIntOp : public StatelessRandomOpBase { + public: + using StatelessRandomOpBase::StatelessRandomOpBase; -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); + void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) override { + const Tensor& minval = context->input(2); + const Tensor& maxval = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. Note that we'll never reach this point for + // empty output. Zero impossible things are fine. + const auto lo = minval.scalar<IntType>()(); + const auto hi = maxval.scalar<IntType>()(); + OP_REQUIRES( + context, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution<random::PhiloxRandom, IntType> + Distribution; + Distribution dist(lo, hi); + + auto flat = output->flat<IntType>(); + // Reuse the compute kernels from the stateful random ops + functor::FillPhiloxRandom<Device, Distribution>()( + context, context->eigen_device<Device>(), random, flat.data(), + flat.size(), dist); + } +}; -#undef REGISTER +#define REGISTER(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomUniform") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessTruncatedNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp< \ + DEVICE##Device, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); + +#define REGISTER_INT(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomUniformIntOp<DEVICE##Device, TYPE>); + +#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE) +#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) +#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE) +#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE) + +TF_CALL_half(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_int32(REGISTER_INT_CPU); +TF_CALL_int64(REGISTER_INT_CPU); #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - GPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +TF_CALL_int32(REGISTER_INT_GPU); +TF_CALL_int64(REGISTER_INT_GPU); -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); +#endif // GOOGLE_CUDA #undef REGISTER - -#endif // GOOGLE_CUDA +#undef REGISTER_INT +#undef REGISTER_CPU +#undef REGISTER_GPU +#undef REGISTER_INT_CPU +#undef REGISTER_INT_GPU } // namespace |