diff options
Diffstat (limited to 'tensorflow/core/ops/random_ops.cc')
-rw-r--r-- | tensorflow/core/ops/random_ops.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index 776523f33f..7b2da9d8e6 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. )doc"); +REGISTER_OP("RandomPoisson") + .SetIsStateful() + .Input("shape: S") + .Input("rate: dtype") + .Output("output: dtype") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("S: {int32, int64}") + .Attr("dtype: {half, float, double}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Outputs random values from the Poisson distribution(s) described by rate. + +This op uses two algorithms, depending on rate. If rate >= 10, then +the algorithm by Hormann is used to acquire samples via +transformation-rejection. +See http://www.sciencedirect.com/science/article/pii/0167668793909974. + +Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +random variables. +See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +Programming, Volume 2. Addison Wesley + +shape: 1-D integer tensor. Shape of independent samples to draw from each + distribution described by the shape parameters given in rate. +rate: A tensor in which each scalar is a "rate" parameter describing the + associated poisson distribution. +seed: If either `seed` or `seed2` are set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, it is seeded by a + random seed. +seed2: A second seed to avoid seed collision. + +output: A tensor with shape `shape + shape(rate)`. Each slice + `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for + `rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of + rate. +)doc"); + } // namespace tensorflow |