diff options
author | 2017-11-27 14:06:23 -0800 | |
---|---|---|
committer | 2017-11-27 14:10:30 -0800 | |
commit | 6cc7e387fc1b642d363b6a18877a411382a82fa5 (patch) | |
tree | 2a35fee29ac8b979908e8323b24a8ec2cb612d8d /tensorflow/core/kernels/stateless_random_ops.cc | |
parent | db9533e4f5fa940f704996cd6d38f40b13d40dff (diff) |
[TF:XLA] Implement StatelessRandomUniform and StatelessRandomNormal using the ThreeFry counter-based PRNG.
Extend stateless ops to allow 32-bit integer seeds, with a 64-bit default.
PiperOrigin-RevId: 177068747
Diffstat (limited to 'tensorflow/core/kernels/stateless_random_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/stateless_random_ops.cc | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index f6fb0a121d..88fcf542fb 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -50,9 +50,18 @@ class StatelessRandomOpBase : public OpKernel { if (shape.num_elements() == 0) return; // Grab the two seeds - const auto seed = seed_t.flat<int64>(); - const uint64 seed0 = internal::SubtleMustCopy(seed(0)); - const uint64 seed1 = internal::SubtleMustCopy(seed(1)); + uint64 seed0; + uint64 seed1; + if (context->input_dtype(1) == DT_INT32) { + const auto seed = seed_t.flat<int32>(); + seed0 = internal::SubtleMustCopy(seed(0)); + seed1 = internal::SubtleMustCopy(seed(1)); + } else { + CHECK_EQ(DT_INT64, context->input_dtype(1)); + const auto seed = seed_t.flat<int64>(); + seed0 = internal::SubtleMustCopy(seed(0)); + seed1 = internal::SubtleMustCopy(seed(1)); + } // Scramble the seeds so that the user doesn't need to worry about which // part of the seed needs to be strong. |