aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/stateless_random_ops.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-11-27 14:06:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-27 14:10:30 -0800
commit6cc7e387fc1b642d363b6a18877a411382a82fa5 (patch)
tree2a35fee29ac8b979908e8323b24a8ec2cb612d8d /tensorflow/core/kernels/stateless_random_ops.cc
parentdb9533e4f5fa940f704996cd6d38f40b13d40dff (diff)
[TF:XLA] Implement StatelessRandomUniform and StatelessRandomNormal using the ThreeFry counter-based PRNG.
Extend stateless ops to allow 32-bit integer seeds, with a 64-bit default. PiperOrigin-RevId: 177068747
Diffstat (limited to 'tensorflow/core/kernels/stateless_random_ops.cc')
-rw-r--r--tensorflow/core/kernels/stateless_random_ops.cc15
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index f6fb0a121d..88fcf542fb 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -50,9 +50,18 @@ class StatelessRandomOpBase : public OpKernel {
if (shape.num_elements() == 0) return;
// Grab the two seeds
- const auto seed = seed_t.flat<int64>();
- const uint64 seed0 = internal::SubtleMustCopy(seed(0));
- const uint64 seed1 = internal::SubtleMustCopy(seed(1));
+ uint64 seed0;
+ uint64 seed1;
+ if (context->input_dtype(1) == DT_INT32) {
+ const auto seed = seed_t.flat<int32>();
+ seed0 = internal::SubtleMustCopy(seed(0));
+ seed1 = internal::SubtleMustCopy(seed(1));
+ } else {
+ CHECK_EQ(DT_INT64, context->input_dtype(1));
+ const auto seed = seed_t.flat<int64>();
+ seed0 = internal::SubtleMustCopy(seed(0));
+ seed1 = internal::SubtleMustCopy(seed(1));
+ }
// Scramble the seeds so that the user doesn't need to worry about which
// part of the seed needs to be strong.