From 1e104d80826fed95f9fad6f07f68e35cae3527b2 Mon Sep 17 00:00:00 2001
From: Geoffrey Irving
Date: Wed, 19 Sep 2018 09:33:19 -0700
Subject: Expand stateless random generators to match their stateful cousins

stateless_random_uniform now takes minval+maxval and handles ints, and
stateless_normal/stateless_truncated_normal take mean+stddev.  Additionally,
all of the stateless functions now have proper doc strings.

This is step one of moving stateless random numbers out of contrib.
---
 .../api_def_StatelessRandomUniformInt.pbtxt     |  46 ++++++
 tensorflow/core/kernels/random_op.cc            |  34 +++--
 tensorflow/core/kernels/stateless_random_ops.cc | 155 +++++++++++++--------
 tensorflow/core/ops/stateless_random_ops.cc     |  53 ++++---
 4 files changed, 192 insertions(+), 96 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
(limited to 'tensorflow/core')

diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+  graph_op_name: "StatelessRandomUniformInt"
+  visibility: HIDDEN
+  in_arg {
+    name: "shape"
+    description: <<END
+The shape of the output tensor.
+END
+  }
+  in_arg {
+    name: "seed"
+    description: <<END
+2 seeds (shape [2]).
+END
+  }
+  in_arg {
+    name: "minval"
+    description: <<END
+Minimum value (inclusive, scalar).
+END
+  }
+  in_arg {
+    name: "maxval"
+    description: <<END
+Maximum value (exclusive, scalar).
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+Random values with specified shape.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The type of the output.
+END
+  }
+  summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
+  description: <<END
+The generated values follow a uniform distribution in the range
+`[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and
+`maxval`.
+END
+}
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ ... @@
       Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
       RandomGammaOp<TYPE>)
 
-#define REGISTER_INT(IntType) \
-  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
-                              .Device(DEVICE_CPU) \
-                              .HostMemory("shape") \
-                              .HostMemory("minval") \
-                              .HostMemory("maxval") \
-                              .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+  template struct functor::FillPhiloxRandom< \
+      CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+                              .Device(DEVICE_CPU) \
+                              .HostMemory("shape") \
+                              .HostMemory("minval") \
+                              .HostMemory("maxval") \
+                              .TypeConstraint<IntType>("Tout"), \
                           RandomUniformIntOp<CPUDevice, IntType>);
 
 TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
           random::TruncatedNormalDistribution< \
               random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
 
-#define REGISTER_INT(IntType) \
-  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
-                              .Device(DEVICE_GPU) \
-                              .HostMemory("shape") \
-                              .HostMemory("minval") \
-                              .HostMemory("maxval") \
-                              .TypeConstraint<int32>("T") \
-                              .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+  template struct functor::FillPhiloxRandom< \
+      GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+                              .Device(DEVICE_GPU) \
+                              .HostMemory("shape") \
+                              .HostMemory("minval") \
+                              .HostMemory("maxval") \
+                              .TypeConstraint<int32>("T") \
+                              .TypeConstraint<IntType>("Tout"), \
                           RandomUniformIntOp<GPUDevice, IntType>);
 
 TF_CALL_half(REGISTER);
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
   }
 };
 
-#define REGISTER(TYPE) \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessRandomUniform") \
-          .Device(DEVICE_CPU) \
-          .HostMemory("shape") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp<CPUDevice, random::UniformDistribution< \
-                            random::PhiloxRandom, TYPE> >); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessRandomNormal") \
-          .Device(DEVICE_CPU) \
-          .HostMemory("shape") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp<CPUDevice, random::NormalDistribution< \
-                            random::PhiloxRandom, TYPE> >); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessTruncatedNormal") \
-          .Device(DEVICE_CPU) \
-          .HostMemory("shape") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp< \
-          CPUDevice, \
-          random::TruncatedNormalDistribution< \
-              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+  using StatelessRandomOpBase::StatelessRandomOpBase;
 
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+  void Fill(OpKernelContext* context, random::PhiloxRandom random,
+            Tensor* output) override {
+    const Tensor& minval = context->input(2);
+    const Tensor& maxval = context->input(3);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+                errors::InvalidArgument("minval must be 0-D, got shape ",
+                                        minval.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+                errors::InvalidArgument("maxval must be 0-D, got shape ",
+                                        maxval.shape().DebugString()));
+
+    // Verify that minval < maxval.  Note that we'll never reach this point for
+    // empty output.  Zero impossible things are fine.
+    const auto lo = minval.scalar<IntType>()();
+    const auto hi = maxval.scalar<IntType>()();
+    OP_REQUIRES(
+        context, lo < hi,
+        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+    // Build distribution
+    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+        Distribution;
+    Distribution dist(lo, hi);
+
+    auto flat = output->flat<IntType>();
+    // Reuse the compute kernels from the stateful random ops
+    functor::FillPhiloxRandom<Device, Distribution>()(
+        context, context->eigen_device<Device>(), random, flat.data(),
+        flat.size(), dist);
+  }
+};
 
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("StatelessRandomUniform") \
+          .Device(DEVICE_##DEVICE) \
+          .HostMemory("shape") \
+          .HostMemory("seed") \
+          .TypeConstraint<TYPE>("dtype"), \
+      StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+                            random::PhiloxRandom, TYPE> >); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("StatelessRandomNormal") \
+          .Device(DEVICE_##DEVICE) \
+          .HostMemory("shape") \
+          .HostMemory("seed") \
+          .TypeConstraint<TYPE>("dtype"), \
+      StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+                            random::PhiloxRandom, TYPE> >); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("StatelessTruncatedNormal") \
+          .Device(DEVICE_##DEVICE) \
+          .HostMemory("shape") \
+          .HostMemory("seed") \
+          .TypeConstraint<TYPE>("dtype"), \
+      StatelessRandomOp< \
+          DEVICE##Device, \
+          random::TruncatedNormalDistribution< \
+              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+                              .Device(DEVICE_##DEVICE) \
+                              .HostMemory("shape") \
+                              .HostMemory("seed") \
+                              .HostMemory("minval") \
+                              .HostMemory("maxval") \
+                              .TypeConstraint<TYPE>("dtype"), \
+                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
 
 #if GOOGLE_CUDA
 
-#define REGISTER(TYPE) \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessRandomUniform") \
-          .Device(DEVICE_GPU) \
-          .HostMemory("shape") \
-          .HostMemory("seed") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp<GPUDevice, random::UniformDistribution< \
-                            random::PhiloxRandom, TYPE> >); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessRandomNormal") \
-          .Device(DEVICE_GPU) \
-          .HostMemory("shape") \
-          .HostMemory("seed") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp<GPUDevice, random::NormalDistribution< \
-                            random::PhiloxRandom, TYPE> >); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("StatelessTruncatedNormal") \
-          .Device(DEVICE_GPU) \
-          .HostMemory("shape") \
-          .HostMemory("seed") \
-          .TypeConstraint<TYPE>("dtype"), \
-      StatelessRandomOp< \
-          GPUDevice, \
-          random::TruncatedNormalDistribution< \
-              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
 
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif  // GOOGLE_CUDA
 
 #undef REGISTER
-
-#endif  // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
 
 }  // namespace
 
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
 namespace tensorflow {
 
 using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
   // Check seed shape
   ShapeHandle seed;
-  TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
   DimensionHandle unused;
-  TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
 
   // Set output shape
   ShapeHandle out;
-  TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
-  context->set_output(0, out);
+  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+  c->set_output(0, out);
   return Status::OK();
 }
 
-#define REGISTER_STATELESS_OP(name) \
-  REGISTER_OP(name) \
-      .Input("shape: T") \
-      .Input("seed: Tseed") \
-      .Output("output: dtype") \
-      .Attr("dtype: {half,float,double} = DT_FLOAT") \
-      .Attr("T: {int32, int64} = DT_INT32") \
-      .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+  REGISTER_OP(name) \
+      .Input("shape: T") \
+      .Input("seed: Tseed") \
+      .Output("output: dtype") \
+      .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+      .Attr("T: {int32, int64} = DT_INT32") \
+      .Attr("Tseed: {int32, int64} = DT_INT64") \
       .SetShapeFn(StatelessShape)
 
-// This op is exposed through contrib/stateless only. The interface may change.
 REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
 REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
 REGISTER_STATELESS_OP("StatelessTruncatedNormal");
 
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+    .Input("shape: T")
+    .Input("seed: Tseed")
+    .Input("minval: dtype")
+    .Input("maxval: dtype")
+    .Output("output: dtype")
+    .Attr("dtype: {int32, int64}")
+    .Attr("T: {int32, int64}")
+    .Attr("Tseed: {int32, int64} = DT_INT64")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      return StatelessShape(c);
+    });
+
 REGISTER_OP("StatelessMultinomial")
     .Input("logits: T")
    .Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
       return Status::OK();
     });
 
-#undef REGISTER_STATELESS_OP
-
 }  // namespace tensorflow
--
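For readers new to this code path, below is a minimal standalone sketch (not
part of the patch) of the machinery the new StatelessRandomUniformIntOp reuses:
a counter-based PhiloxRandom generator plus an integer UniformDistribution
constructed with (minval, maxval). It assumes TensorFlow's internal random
headers are available and feeds two fixed values straight into Philox, whereas
the real op derives the Philox key and counter from the "seed" input tensor.

// Sketch only, not from the patch; assumes TensorFlow internal headers.
#include <iostream>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/types.h"

int main() {
  using tensorflow::random::PhiloxRandom;
  using tensorflow::random::UniformDistribution;

  // Two fixed 64-bit values stand in for the processed (seed[0], seed[1]) pair.
  PhiloxRandom gen(/*seed_lo=*/0x1234, /*seed_hi=*/0x5678);

  // Same semantics as the new kernel: minval is inclusive, maxval exclusive.
  using Dist = UniformDistribution<PhiloxRandom, tensorflow::int32>;
  Dist dist(/*lo=*/0, /*hi=*/10);

  // One call yields a small fixed-size batch of samples; FillPhiloxRandom
  // repeats this (in parallel shards) to cover the whole output tensor.
  auto samples = dist(&gen);
  for (int i = 0; i < Dist::kResultElementCount; ++i) {
    std::cout << samples[i] << " ";  // identical on every run: stateless
  }
  std::cout << "\n";
  return 0;
}

Because the generator is a pure function of its inputs, rerunning this program
prints the same values every time; that determinism is exactly what the
stateless ops expose at graph level.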