aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/graph/testlib.cc10
-rw-r--r--tensorflow/core/graph/testlib.h4
-rw-r--r--tensorflow/core/kernels/BUILD27
-rw-r--r--tensorflow/core/kernels/random_op.cc2
-rw-r--r--tensorflow/core/kernels/random_poisson_op.cc357
-rw-r--r--tensorflow/core/kernels/random_poisson_op.h31
-rw-r--r--tensorflow/core/kernels/random_poisson_op_test.cc82
-rw-r--r--tensorflow/core/ops/random_ops.cc44
-rw-r--r--tensorflow/core/ops/random_ops_test.cc15
10 files changed, 572 insertions, 1 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2c75353ffe..4dd9bffe80 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -635,6 +635,7 @@ cc_library(
"//tensorflow/core/kernels:parameterized_truncated_normal_op",
"//tensorflow/core/kernels:parsing",
"//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:random_poisson_op",
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ef4dd04787..f0ab5520f1 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -219,6 +219,16 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
return ret;
}
+Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
+ .Input(shape)
+ .Input(lam)
+ .Attr("seed", 0)
+ .Finalize(g, &ret));
+ return ret;
+}
+
Node* Unary(Graph* g, const string& func, Node* input, int index) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 7a23b20c2c..d508f65ada 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -113,6 +113,10 @@ Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
// Output dtype determined by alpha.
Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
+// Generates Poisson-distributed random samples with the given shape and lam[s].
+// Output dtype determined by lam.
+Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
+
// Generates random parameters from the truncated standard normal distribution
// of the nput shape
Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ab5923b0e7..06a11d31ab 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3362,6 +3362,33 @@ tf_cuda_cc_test(
)
tf_kernel_library(
+ name = "random_poisson_op",
+ prefix = "random_poisson_op",
+ deps = [
+ ":random_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:random_ops_op_lib",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "random_poisson_op_test",
+ size = "small",
+ srcs = ["random_poisson_op_test.cc"],
+ deps = [
+ ":ops_util",
+ ":random_poisson_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_kernel_library(
name = "word2vec_kernels",
prefix = "word2vec_kernels",
deps = [
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 0a1de11162..f3c7e0f26b 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -541,7 +541,7 @@ TF_CALL_int64(REGISTER_INT);
PhiloxRandomOp< \
GPUDevice, \
random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
#define REGISTER_INT(IntType) \
REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc
new file mode 100644
index 0000000000..553a4a7f93
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op.cc
@@ -0,0 +1,357 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/random_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/random_poisson_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if EIGEN_COMP_GNUC && __cplusplus > 199711L
+#define DISABLE_FLOAT_EQUALITY_WARNING \
+ _Pragma("GCC diagnostic push") \
+ _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
+#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
+#else
+#define DISABLE_FLOAT_EQUALITY_WARNING
+#define ENABLE_FLOAT_EQUALITY_WARNING
+#endif
+
+#define UNIFORM(X) \
+ if (uniform_remaining == 0) { \
+ uniform_remaining = Uniform::kResultElementCount; \
+ uniform_result = uniform(&gen); \
+ } \
+ uniform_remaining--; \
+ CT X = uniform_result[uniform_remaining]
+
+namespace tensorflow {
+namespace {
+
+static constexpr int kReservedSamplesPerOutput = 256;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// We will compute half-precision Poisson samples with float precision
+// intermediate calculations.
+template <typename T>
+struct PoissonComputeType {
+ typedef T ComputeType;
+};
+
+template <>
+struct PoissonComputeType<Eigen::half> {
+ typedef float ComputeType;
+};
+
+} // namespace
+
+namespace functor {
+
+template <typename Device, typename T>
+struct PoissonFunctor {
+ void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
+ int num_rate, int num_samples,
+ const random::PhiloxRandom& rng, T* samples_flat);
+};
+
+template <typename T>
+struct PoissonFunctor<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
+ int num_rate, int num_samples,
+ const random::PhiloxRandom& rng, T* samples_flat) {
+ // Two different algorithms are employed, depending on the size of
+ // rate.
+ // If rate < 10, we use an algorithm attributed to Knuth:
+ // Seminumerical Algorithms. Art of Computer Programming, Volume 2.
+ //
+ // This algorithm runs in O(rate) time, and will require O(rate)
+ // uniform
+ // variates.
+ //
+ // If rate >= 10 we use a transformation-rejection algorithm from
+ // pairs
+ // of uniform random variables due to Hormann.
+ // http://www.sciencedirect.com/science/article/pii/0167668793909974
+ //
+ // The algorithm has an acceptance rate of ~75% for the smallest rate
+ // (~10),
+ // and higher accept rates for higher rate (approaching ~89%), so runtime
+ // is O(NumRate * NumSamples * k) with 1 / 0.89 <= k <= 1 / 0.75.
+ //
+ // We partition work first across rates then across
+ // samples-per-rate to
+ // avoid a couple flops which can be done on a per-rate basis.
+
+ typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
+
+ auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
+ int start_output, int limit_output) {
+ // Capturing "rng" by value would only make a copy for the _shared_
+ // lambda. Since we want to let each worker have its own copy, we pass
+ // "rng" by reference and explicitly do a copy assignment.
+
+ Uniform uniform;
+ typename Uniform::ResultType uniform_result;
+ for (int64 output_idx = start_output; output_idx < limit_output;
+ /* output_idx incremented within inner loop below */) {
+ const int64 rate_idx = output_idx / num_samples;
+
+ // Several calculations can be done on a per-rate basis.
+ const CT rate = CT(rate_flat[rate_idx]);
+
+ auto samples_rate_output = samples_flat + rate_idx;
+
+ if (rate < CT(10)) {
+ // Knuth's algorithm for generating Poisson random variates.
+ // Given a Poisson process, the time between events is exponentially
+ // distributed. If we have a Poisson process with rate lambda, then,
+ // the time between events is distributed Exp(lambda). If X ~
+ // Uniform(0, 1), then Y ~ Exp(lambda), where Y = -log(X) / lambda.
+ // Thus to simulate a Poisson draw, we can draw Y_i ~ Exp(lambda),
+ // and N ~ Poisson(lambda), where N + 1 is the least number such
+ // that \sum_i^{N + 1} Y_i > 1.
+ const CT exp_neg_rate = Eigen::numext::exp(-rate);
+
+ // Compute the rest of the samples for the current rate value.
+ for (int64 sample_idx = output_idx % num_samples;
+ sample_idx < num_samples && output_idx < limit_output;
+ sample_idx++, output_idx++) {
+ random::PhiloxRandom gen = rng;
+ gen.Skip(kReservedSamplesPerOutput * output_idx);
+ int16 uniform_remaining = 0;
+
+ CT prod = 1;
+ CT x = 0;
+
+ // Keep trying until we surpass e^(-rate). This will take
+ // expected time proportional to rate.
+ while (true) {
+ UNIFORM(u);
+ prod = prod * u;
+ if (prod <= exp_neg_rate) {
+ samples_rate_output[sample_idx * num_rate] = T(x);
+ break;
+ }
+ x += 1;
+ }
+ }
+ continue;
+ }
+ // Transformed rejection due to Hormann.
+ //
+ // Given a CDF F(x), and G(x), a dominating distribution chosen such
+ // that it is close to the inverse CDF F^-1(x), compute the following
+ // steps:
+ //
+ // 1) Generate U and V, two independent random variates. Set U = U - 0.5
+ // (this step isn't strictly necessary, but is done to make some
+ // calculations symmetric and convenient. Henceforth, G is defined on
+ // [-0.5, 0.5]).
+ //
+ // 2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return
+ // to step 1. alpha is the acceptance probability of the rejection
+ // algorithm.
+ //
+ // For more details on transformed rejection, see:
+ // http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
+ //
+ // The dominating distribution in this case:
+ //
+ // G(u) = (2 * a / (0.5 - |u|) + b) * u + c
+
+ using Eigen::numext::log;
+ const CT log_rate = log(rate);
+
+ // Constants used to define the dominating distribution. Names taken
+ // from Hormann's paper. Constants were chosen to define the tightest
+ // G(u) for the inverse Poisson CDF.
+ const CT b = CT(0.931) + CT(2.53) * Eigen::numext::sqrt(rate);
+ const CT a = CT(-0.059) + CT(0.02483) * b;
+
+ // This is the inverse acceptance rate. At a minimum (when rate = 10),
+ // this corresponds to ~75% acceptance. As the rate becomes larger, this
+ // approaches ~89%.
+ const CT inv_alpha = CT(1.1239) + CT(1.1328) / (b - CT(3.4));
+
+ // Compute the rest of the samples for the current rate value.
+ for (int64 sample_idx = output_idx % num_samples;
+ sample_idx < num_samples && output_idx < limit_output;
+ sample_idx++, output_idx++) {
+ random::PhiloxRandom gen = rng;
+ gen.Skip(kReservedSamplesPerOutput * output_idx);
+ int16 uniform_remaining = 0;
+
+ while (true) {
+ UNIFORM(u);
+ u -= CT(0.5);
+ UNIFORM(v);
+
+ CT u_shifted = CT(0.5) - Eigen::numext::abs(u);
+ CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
+ CT(0.43));
+
+ // When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
+ // find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
+ // that if v <= v_r and |u| <= u_r, then we can accept.
+ // Here v_r = 0.9277 - 3.6224 / (b - 2) and u_r = 0.43.
+ if (u_shifted >= CT(0.07) &&
+ v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
+ samples_rate_output[sample_idx * num_rate] = T(k);
+ break;
+ }
+
+ if (k < 0 || (u_shifted < CT(0.013) && v > u_shifted)) {
+ continue;
+ }
+
+ // The expression below is equivalent to the computation of step 2)
+ // in transformed rejection (v <= alpha * F'(G(u)) * G'(u)).
+ CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
+ CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
+ if (s <= t) {
+ samples_rate_output[sample_idx * num_rate] = T(k);
+ break;
+ }
+ }
+ }
+ }
+ };
+
+ // This will depend on rate.
+ // For rate < 10, on average, O(rate) calls to uniform are
+ // needed, with that
+ // many multiplies. ~10 uniform calls on average with ~25 cost op calls.
+ //
+ // Very roughly, for rate >= 10, the single call to log + call to
+ // lgamma
+ // occur for ~60 percent of samples.
+ // 2 x 100 (64-bit cycles per log) * 0.62 = ~124
+ // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
+ // 40 * .62 = ~25.
+ //
+ // Finally, there are several other ops that are done every loop along with
+ // 2 uniform generations along with 5 other ops at 3-6 cycles each.
+ // ~15 / .89 = ~16
+ //
+ // In total this should be ~165 + 2 * Uniform::kElementCost.
+ // We assume that half the tensor has rate < 10, so on average 6
+ // uniform's
+ // will be needed. We will upper bound the other op cost by the one for
+ // rate > 10.
+ static const int kElementCost = 165 + 6 * Uniform::kElementCost +
+ 6 * random::PhiloxRandom::kElementCost;
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers,
+ num_rate * num_samples, kElementCost, DoWork);
+ }
+
+ private:
+ typedef typename PoissonComputeType<T>::ComputeType CT;
+};
+
+} // namespace functor
+
+namespace {
+
+// Samples from one or more Poisson distributions.
+template <typename T>
+class RandomPoissonOp : public OpKernel {
+ public:
+ explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& shape_t = ctx->input(0);
+ const Tensor& rate_t = ctx->input(1);
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(shape_t.shape()) &&
+ (shape_t.dtype() == DataType::DT_INT32 ||
+ shape_t.dtype() == DataType::DT_INT64),
+ errors::InvalidArgument(
+ "shape must be a vector of {int32,int64}, got shape: ",
+ shape_t.DebugString()));
+ TensorShape samples_shape;
+ if (shape_t.dtype() == DataType::DT_INT32) {
+ auto vec = shape_t.flat<int32>();
+ OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
+ &samples_shape));
+ } else if (shape_t.dtype() == DataType::DT_INT64) {
+ auto vec = shape_t.flat<int64>();
+ OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
+ &samples_shape));
+ }
+ const int64 num_samples = samples_shape.num_elements();
+ OP_REQUIRES(ctx, num_samples > 0,
+ errors::InvalidArgument(
+ "Input shape should have non-zero element count, got: ",
+ num_samples));
+
+ samples_shape.AppendShape(rate_t.shape());
+ // Allocate output samples.
+ Tensor* samples_t = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
+
+ const auto rate_flat = rate_t.flat<T>().data();
+ const int64 num_rate = rate_t.NumElements();
+ OP_REQUIRES(
+ ctx, num_rate > 0,
+ errors::InvalidArgument(
+ "Input rate should have non-zero element count, got: ", num_rate));
+ auto samples_flat = samples_t->flat<T>().data();
+ random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
+ num_samples * num_rate, kReservedSamplesPerOutput);
+
+ functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(),
+ rate_flat, num_rate, num_samples,
+ rng, samples_flat);
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomPoissonOp);
+};
+} // namespace
+
+#undef UNIFORM
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
+ RandomPoissonOp<TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+
+#undef REGISTER
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h
new file mode 100644
index 0000000000..6c49acc800
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op.h
@@ -0,0 +1,31 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+
+namespace tensorflow {
+
+namespace functor {
+
+// Generic helper functor for the Random Poisson Op.
+template <typename Device, typename T>
+struct PoissonFunctor;
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
diff --git a/tensorflow/core/kernels/random_poisson_op_test.cc b/tensorflow/core/kernels/random_poisson_op_test.cc
new file mode 100644
index 0000000000..bccdbf6c7f
--- /dev/null
+++ b/tensorflow/core/kernels/random_poisson_op_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <random>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Tensor VecShape(int64 v) {
+ if (v >= std::numeric_limits<int32>::max()) {
+ Tensor shape(DT_INT64, TensorShape({1}));
+ shape.vec<int64>()(0) = v;
+ return shape;
+ } else {
+ Tensor shape(DT_INT32, TensorShape({1}));
+ shape.vec<int32>()(0) = v;
+ return shape;
+ }
+}
+
+Tensor VecLam32(int64 n, int magnitude) {
+ std::mt19937 gen(0x12345);
+ std::uniform_real_distribution<float> dist(0.0, 1.0);
+ Tensor lams(DT_FLOAT, TensorShape({n}));
+ for (int i = 0; i < n; i++) {
+ // Generate in range [magnitude, 2 * magnitude)
+ lams.vec<float>()(i) = magnitude * (1 + dist(gen));
+ }
+ return lams;
+}
+
+Tensor VecLam64(int64 n, int magnitude) {
+ std::mt19937 gen(0x12345);
+ std::uniform_real_distribution<double> dist(0.0, 1.0);
+ Tensor lams(DT_DOUBLE, TensorShape({n}));
+ for (int i = 0; i < n; i++) {
+ // Generate in range [magnitude, 2 * magnitude)
+ lams.vec<double>()(i) = magnitude * (1 + dist(gen));
+ }
+ return lams;
+}
+
+#define BM_Poisson(DEVICE, BITS, MAGNITUDE) \
+ static void BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS( \
+ int iters, int nsamp, int nlam) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nlam); \
+ Graph* g = new Graph(OpRegistry::Global()); \
+ test::graph::RandomPoisson( \
+ g, test::graph::Constant(g, VecShape(nsamp)), \
+ test::graph::Constant(g, VecLam##BITS(nlam, MAGNITUDE))); \
+ test::Benchmark(#DEVICE, g).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_RandomPoisson_lam_##MAGNITUDE##_##BITS) \
+ ->RangePair(1, 64, 2, 50);
+
+BM_Poisson(cpu, 32, 1);
+BM_Poisson(cpu, 32, 8);
+BM_Poisson(cpu, 32, 32);
+
+BM_Poisson(cpu, 64, 1);
+BM_Poisson(cpu, 64, 8);
+BM_Poisson(cpu, 64, 32);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 776523f33f..7b2da9d8e6 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -276,4 +276,48 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
)doc");
+REGISTER_OP("RandomPoisson")
+ .SetIsStateful()
+ .Input("shape: S")
+ .Input("rate: dtype")
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("S: {int32, int64}")
+ .Attr("dtype: {half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs random values from the Poisson distribution(s) described by rate.
+
+This op uses two algorithms, depending on rate. If rate >= 10, then
+the algorithm by Hormann is used to acquire samples via
+transformation-rejection.
+See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+
+Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+random variables.
+See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+Programming, Volume 2. Addison Wesley
+
+shape: 1-D integer tensor. Shape of independent samples to draw from each
+ distribution described by the shape parameters given in rate.
+rate: A tensor in which each scalar is a "rate" parameter describing the
+ associated poisson distribution.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor with shape `shape + shape(rate)`. Each slice
+ `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+ `rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of
+ rate.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/random_ops_test.cc b/tensorflow/core/ops/random_ops_test.cc
index 524e107998..b0aa565485 100644
--- a/tensorflow/core/ops/random_ops_test.cc
+++ b/tensorflow/core/ops/random_ops_test.cc
@@ -53,4 +53,19 @@ TEST(RandomOpsTest, RandomGamma_ShapeFn) {
INFER_OK(op, "[3];[]", "[1,2,3]");
}
+TEST(RandomOpsTest, RandomPoisson_ShapeFn) {
+ ShapeInferenceTestOp op("RandomPoisson");
+ op.input_tensors.resize(2);
+
+ INFER_OK(op, "?;?", "?");
+ INFER_OK(op, "?;[3]", "?");
+ INFER_OK(op, "[1];?", "?");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];[3,4]");
+ Tensor shape = test::AsTensor<int32>({1, 2, 3});
+ op.input_tensors[0] = &shape;
+ INFER_OK(op, "[3];[4,?]", "[1,2,3,d1_0,d1_1]");
+ INFER_OK(op, "[3];[4,5]", "[1,2,3,d1_0,d1_1]");
+ INFER_OK(op, "[3];[]", "[1,2,3]");
+}
+
} // end namespace tensorflow