aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/random_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/random_ops.cc')
-rw-r--r--tensorflow/core/ops/random_ops.cc44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 776523f33f..7b2da9d8e6 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
)doc");
+REGISTER_OP("RandomPoisson")
+ .SetIsStateful()
+ .Input("shape: S")
+ .Input("rate: dtype")
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("S: {int32, int64}")
+ .Attr("dtype: {half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs random values from the Poisson distribution(s) described by rate.
+
+This op uses two algorithms, depending on rate. If rate >= 10, then
+the algorithm by Hormann is used to acquire samples via
+transformation-rejection.
+See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+
+Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+random variables.
+See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+Programming, Volume 2. Addison Wesley
+
+shape: 1-D integer tensor. Shape of independent samples to draw from each
+ distribution described by the shape parameters given in rate.
+rate: A tensor in which each scalar is a "rate" parameter describing the
+ associated poisson distribution.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor with shape `shape + shape(rate)`. Each slice
+ `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+ `rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of
+ rate.
+)doc");
+
} // namespace tensorflow