diff options
author | Geoffrey Irving <irving@naml.us> | 2018-09-19 09:33:19 -0700 |
---|---|---|
committer | Geoffrey Irving <irving@naml.us> | 2018-10-05 15:02:11 -0700 |
commit | 1e104d80826fed95f9fad6f07f68e35cae3527b2 (patch) | |
tree | bb934f3e78f55f53fadbdba8408e550f1f61e171 | |
parent | 80c9eec9b2475630f83a596f77a906c8075f8e6c (diff) |
Expand stateless random generators to match their stateful cousins
stateless_random_uniform now takes minval+maxval and handles ints,
and stateless_normal/stateless_truncated_normal take mean+stddev.
Additionally, all of the stateless functions now have proper doc
strings.
This is step one of moving stateless random numbers out of contrib.
-rw-r--r-- | tensorflow/contrib/stateless/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/contrib/stateless/__init__.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py | 156 | ||||
-rw-r--r-- | tensorflow/contrib/stateless/python/stateless_ops.py | 214 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt | 46 | ||||
-rw-r--r-- | tensorflow/core/kernels/random_op.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/kernels/stateless_random_ops.cc | 155 | ||||
-rw-r--r-- | tensorflow/core/ops/stateless_random_ops.cc | 53 |
8 files changed, 491 insertions, 181 deletions
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD index a217397c1a..e9ddec8889 100644 --- a/tensorflow/contrib/stateless/BUILD +++ b/tensorflow/contrib/stateless/BUILD @@ -11,7 +11,10 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") py_library( name = "stateless", - srcs = ["__init__.py"], + srcs = [ + "__init__.py", + "python/stateless_ops.py", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py index fe23fe0dd8..30d0a7ab6a 100644 --- a/tensorflow/contrib/stateless/__init__.py +++ b/tensorflow/contrib/stateless/__init__.py @@ -32,16 +32,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops - # pylint: disable=wildcard-import -from tensorflow.python.ops.gen_stateless_random_ops import * +from tensorflow.contrib.stateless.python.stateless_ops import * from tensorflow.python.util.all_util import remove_undocumented -ops.NotDifferentiable("StatelessMultinomial") -ops.NotDifferentiable("StatelessRandomNormal") -ops.NotDifferentiable("StatelessRandomUniform") -ops.NotDifferentiable("StatelessTruncatedNormal") - remove_undocumented(__name__) diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py index d724a5c014..c0c1430d84 100644 --- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py +++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + import numpy as np from tensorflow.contrib import stateless from tensorflow.python.framework import constant_op @@ -27,10 +29,6 @@ from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform), - (stateless.stateless_random_normal, random_ops.random_normal), - (stateless.stateless_truncated_normal, random_ops.truncated_normal)] - def invert_philox(key, value): """Invert the Philox bijection.""" @@ -51,96 +49,102 @@ def invert_philox(key, value): class StatelessOpsTest(test.TestCase): - def testMatchStateful(self): + def _test_match(self, cases): # Stateless ops should be the same as stateful ops on the first call # after seed scrambling. + cases = tuple(cases) key = 0x3ec8f720, 0x02461e29 for seed in (7, 17), (11, 5), (2, 3): preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64) preseed = preseed[::2] | preseed[1::2] << 32 random_seed.set_random_seed(seed[0]) with self.test_session(use_gpu=True): - for stateless_op, stateful_op in CASES: - for shape in (), (3,), (2, 5): - stateful = stateful_op(shape, seed=seed[1]) - pure = stateless_op(shape, seed=preseed) - self.assertAllEqual(stateful.eval(), pure.eval()) + for stateless_op, stateful_op in cases: + stateful = stateful_op(seed=seed[1]) + pure = stateless_op(seed=preseed) + self.assertAllEqual(stateful.eval(), pure.eval()) - def testDeterminism(self): + def _test_determinism(self, cases): # Stateless values should be equal iff the seeds are equal (roughly) + cases = tuple(cases) with self.test_session(use_gpu=True): for seed_type in [dtypes.int32, dtypes.int64]: seed_t = array_ops.placeholder(seed_type, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 - for stateless_op, _ in CASES: - for shape in (), (3,), (2, 5): - pure = stateless_op(shape, seed=seed_t) - values = [(seed, pure.eval(feed_dict={seed_t: seed})) - for seed in seeds] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) - - def testShapeType(self): - with 
self.test_session(use_gpu=True): - for shape_dtype in [dtypes.int32, dtypes.int64]: - seed_t = array_ops.placeholder(dtypes.int64, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 - for stateless_op, _ in CASES: - for shape in (), (3,), (2, 5): - pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype), - seed=seed_t) - values = [(seed, pure.eval(feed_dict={seed_t: seed})) - for seed in seeds] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) - - def testMatchStatefulMultinomial(self): - # Stateless ops should be the same as stateful ops on the first call - # after seed scrambling. - key = 0x3ec8f720, 0x02461e29 - num_samples = 4 - for logits_dtype in np.float16, np.float32, np.float64: - for output_dtype in dtypes.int32, dtypes.int64: - for seed in (7, 17), (11, 5), (2, 3): - preseed = invert_philox(key, - (seed[0], 0, seed[1], 0)).astype(np.uint64) - preseed = preseed[::2] | preseed[1::2] << 32 - random_seed.set_random_seed(seed[0]) - with self.test_session(use_gpu=True): - for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], - [0.25, 0.75]]): - logits_t = constant_op.constant(logits, dtype=logits_dtype) - stateful = random_ops.multinomial( - logits_t, - num_samples, - seed=seed[1], - output_dtype=output_dtype) - pure = stateless.stateless_multinomial( - logits_t, - num_samples, - seed=preseed, - output_dtype=output_dtype) - self.assertAllEqual(stateful.eval(), pure.eval()) + for stateless_op, _ in cases: + pure = stateless_op(seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) + for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) - def testDeterminismMultinomial(self): - # Stateless values should be equal iff the seeds are equal (roughly) + def _float_cases(self, shape_dtypes=(None,)): + float_cases = ( + # Uniform distribution, with and without range + (stateless.stateless_random_uniform, 
random_ops.random_uniform, {}), + (stateless.stateless_random_uniform, random_ops.random_uniform, + dict(minval=2.2, maxval=7.1)), + # Normal distribution, with and without mean+stddev + (stateless.stateless_random_normal, random_ops.random_normal, {}), + (stateless.stateless_random_normal, random_ops.random_normal, + dict(mean=2, stddev=3)), + # Truncated normal distribution, with and without mean+stddev + (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}), + (stateless.stateless_truncated_normal, random_ops.truncated_normal, + dict(mean=3, stddev=4)), + ) + for dtype in dtypes.float16, dtypes.float32, dtypes.float64: + for shape_dtype in shape_dtypes: + for shape in (), (3,), (2, 5): + if shape_dtype is not None: + shape = constant_op.constant(shape, dtype=shape_dtype) + for stateless_op, stateful_op, kwds in float_cases: + kwds = dict(shape=shape, dtype=dtype, **kwds) + yield (functools.partial(stateless_op, **kwds), + functools.partial(stateful_op, **kwds)) + + def _int_cases(self, shape_dtypes=(None,)): + for shape_dtype in shape_dtypes: + for shape in (), (3,), (2, 5): + if shape_dtype is not None: + shape = constant_op.constant(shape, dtype=shape_dtype) + for dtype in dtypes.int32, dtypes.int64: + kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape) + yield (functools.partial(stateless.stateless_random_uniform, **kwds), + functools.partial(random_ops.random_uniform, **kwds)) + + def _multinomial_cases(self): num_samples = 10 - with self.test_session(use_gpu=True): - for seed_type in [dtypes.int32, dtypes.int64]: - seed_t = array_ops.placeholder(seed_type, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits_dtype in np.float16, np.float32, np.float64: + for output_dtype in dtypes.int32, dtypes.int64: for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]): - pure = stateless.stateless_multinomial( - logits, num_samples, seed=seed_t) - values = [ - (seed, 
pure.eval(feed_dict={seed_t: seed})) for seed in seeds - ] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) + kwds = dict(logits=constant_op.constant(logits, dtype=logits_dtype), + num_samples=num_samples, + output_dtype=output_dtype) + yield (functools.partial(stateless.stateless_multinomial, **kwds), + functools.partial(random_ops.multinomial, **kwds)) + + def testMatchFloat(self): + self._test_match(self._float_cases()) + + def testMatchInt(self): + self._test_match(self._int_cases()) + + def testMatchMultinomial(self): + self._test_match(self._multinomial_cases()) + + def testDeterminismFloat(self): + self._test_determinism(self._float_cases( + shape_dtypes=(dtypes.int32, dtypes.int64))) + + def testDeterminismInt(self): + self._test_determinism(self._int_cases( + shape_dtypes=(dtypes.int32, dtypes.int64))) + + def testDeterminismMultinomial(self): + self._test_determinism(self._multinomial_cases()) if __name__ == '__main__': diff --git a/tensorflow/contrib/stateless/python/stateless_ops.py b/tensorflow/contrib/stateless/python/stateless_ops.py new file mode 100644 index 0000000000..db9b7a87f2 --- /dev/null +++ b/tensorflow/contrib/stateless/python/stateless_ops.py @@ -0,0 +1,214 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stateless random ops which take seed as a tensor input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import gen_stateless_random_ops + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import math_ops + +ops.NotDifferentiable("StatelessMultinomial") +ops.NotDifferentiable("StatelessRandomNormal") +ops.NotDifferentiable("StatelessRandomUniform") +ops.NotDifferentiable("StatelessRandomUniformInt") +ops.NotDifferentiable("StatelessTruncatedNormal") + + +def stateless_random_uniform(shape, + seed, + minval=0, + maxval=None, + dtype=dtypes.float32, + name=None): + """Outputs deterministic pseudorandom values from a uniform distribution. + + This is a stateless version of `tf.random_uniform`: if run twice with the + same seeds, it will produce the same pseudorandom numbers. The output is + consistent across multiple runs on the same hardware (and between CPU + and GPU), but may change between versions of TensorFlow or on non-CPU/GPU + hardware. + + The generated values follow a uniform distribution in the range + `[minval, maxval)`. The lower bound `minval` is included in the range, while + the upper bound `maxval` is excluded. + + For floats, the default range is `[0, 1)`. For ints, at least `maxval` must + be specified explicitly. + + In the integer case, the random integers are slightly biased unless + `maxval - minval` is an exact power of two. The bias is small for values of + `maxval - minval` significantly smaller than the range of the output (either + `2**32` or `2**64`). + + Args: + shape: A 1-D integer Tensor or Python array. The shape of the output tensor. + seed: A shape [2] integer Tensor of seeds to the random number generator. 
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the + range of random values to generate. Defaults to 0. + maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on + the range of random values to generate. Defaults to 1 if `dtype` is + floating point. + dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, + or `int64`. + name: A name for the operation (optional). + + Returns: + A tensor of the specified shape filled with random uniform values. + + Raises: + ValueError: If `dtype` is integral and `maxval` is not specified. + """ + dtype = dtypes.as_dtype(dtype) + if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32, + dtypes.float64, dtypes.int32, dtypes.int64): + raise ValueError("Invalid dtype %r" % dtype) + if maxval is None: + if dtype.is_integer: + raise ValueError("Must specify maxval for integer dtype %r" % dtype) + maxval = 1 + with ops.name_scope(name, "stateless_random_uniform", + [shape, seed, minval, maxval]) as name: + shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access + minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") + maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") + if dtype.is_integer: + return gen_stateless_random_ops.stateless_random_uniform_int( + shape, seed=seed, minval=minval, maxval=maxval, name=name) + else: + rnd = gen_stateless_random_ops.stateless_random_uniform( + shape, seed=seed, dtype=dtype) + return math_ops.add(rnd * (maxval - minval), minval, name=name) + + +def stateless_random_normal(shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None): + """Outputs deterministic pseudorandom values from a normal distribution. + + This is a stateless version of `tf.random_normal`: if run twice with the + same seeds, it will produce the same pseudorandom numbers. 
The output is + consistent across multiple runs on the same hardware (and between CPU + and GPU), but may change between versions of TensorFlow or on non-CPU/GPU + hardware. + + Args: + shape: A 1-D integer Tensor or Python array. The shape of the output tensor. + seed: A shape [2] integer Tensor of seeds to the random number generator. + mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal + distribution. + stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation + of the normal distribution. + dtype: The type of the output. + name: A name for the operation (optional). + + Returns: + A tensor of the specified shape filled with random normal values. + """ + with ops.name_scope(name, "stateless_random_normal", + [shape, seed, mean, stddev]) as name: + shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype) + return math_ops.add(rnd * stddev, mean, name=name) + + +def stateless_truncated_normal(shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None): + """Outputs deterministic pseudorandom values, truncated normally distributed. + + This is a stateless version of `tf.truncated_normal`: if run twice with the + same seeds, it will produce the same pseudorandom numbers. The output is + consistent across multiple runs on the same hardware (and between CPU + and GPU), but may change between versions of TensorFlow or on non-CPU/GPU + hardware. + + The generated values follow a normal distribution with specified mean and + standard deviation, except that values whose magnitude is more than 2 standard + deviations from the mean are dropped and re-picked. + + Args: + shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 
+ seed: A shape [2] integer Tensor of seeds to the random number generator. + mean: A 0-D Tensor or Python value of type `dtype`. The mean of the + truncated normal distribution. + stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation + of the normal distribution, before truncation. + dtype: The type of the output. + name: A name for the operation (optional). + + Returns: + A tensor of the specified shape filled with random truncated normal values. + """ + with ops.name_scope(name, "stateless_truncated_normal", + [shape, seed, mean, stddev]) as name: + shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = gen_stateless_random_ops.stateless_truncated_normal( + shape, seed, dtype) + return math_ops.add(rnd * stddev, mean, name=name) + + +def stateless_multinomial(logits, + num_samples, + seed, + output_dtype=dtypes.int64, + name=None): + """Draws deterministic pseudorandom samples from a multinomial distribution. + + This is a stateless version of `tf.multinomial`: if run twice with the + same seeds, it will produce the same pseudorandom numbers. The output is + consistent across multiple runs on the same hardware (and between CPU + and GPU), but may change between versions of TensorFlow or on non-CPU/GPU + hardware. + + Example: + + ```python + # samples has shape [1, 5], where each value is either 0 or 1 with equal + # probability. + samples = tf.contrib.stateless.stateless_multinomial( + tf.log([[10., 10.]]), 5, seed=[7, 17]) + ``` + + Args: + logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice + `[i, :]` represents the unnormalized log-probabilities for all classes. + num_samples: 0-D. Number of independent samples to draw for each row slice. + seed: A shape [2] integer Tensor of seeds to the random number generator. + name: Optional name for the operation. 
+ output_dtype: integer type to use for the output. Defaults to int64. + + Returns: + The drawn samples of shape `[batch_size, num_samples]`. + """ + with ops.name_scope(name, "stateless_multinomial", [logits, seed]): + logits = ops.convert_to_tensor(logits, name="logits") + return gen_stateless_random_ops.stateless_multinomial( + logits, num_samples, seed, output_dtype=output_dtype) diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt new file mode 100644 index 0000000000..b6a6dbdf54 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt @@ -0,0 +1,46 @@ +op { + graph_op_name: "StatelessRandomUniformInt" + visibility: HIDDEN + in_arg { + name: "shape" + description: <<END +The shape of the output tensor. +END + } + in_arg { + name: "seed" + description: <<END +2 seeds (shape [2]). +END + } + in_arg { + name: "minval" + description: <<END +Minimum value (inclusive, scalar). +END + } + in_arg { + name: "maxval" + description: <<END +Maximum value (exclusive, scalar). +END + } + out_arg { + name: "output" + description: <<END +Random values with specified shape. +END + } + attr { + name: "dtype" + description: <<END +The type of the output. +END + } + summary: "Outputs deterministic pseudorandom random integers from a uniform distribution." + description: <<END +The generated values follow a uniform distribution in the range `[minval, maxval)`. + +The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. 
+END +} diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 04a53697c0..3810d817ca 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel { Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ RandomGammaOp<TYPE>) -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<CPUDevice, IntType>); TF_CALL_half(REGISTER); @@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT); random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>); -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<int32>("T") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<int32>("T") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<GPUDevice, IntType>); TF_CALL_half(REGISTER); diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index eab176c7fb..925f5291a6 
100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +template <typename Device, typename IntType> +class StatelessRandomUniformIntOp : public StatelessRandomOpBase { + public: + using StatelessRandomOpBase::StatelessRandomOpBase; -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); + void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) override { + const Tensor& minval = context->input(2); + const Tensor& maxval = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. Note that we'll never reach this point for + // empty output. Zero impossible things are fine. 
+ const auto lo = minval.scalar<IntType>()(); + const auto hi = maxval.scalar<IntType>()(); + OP_REQUIRES( + context, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution<random::PhiloxRandom, IntType> + Distribution; + Distribution dist(lo, hi); + + auto flat = output->flat<IntType>(); + // Reuse the compute kernels from the stateful random ops + functor::FillPhiloxRandom<Device, Distribution>()( + context, context->eigen_device<Device>(), random, flat.data(), + flat.size(), dist); + } +}; -#undef REGISTER +#define REGISTER(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomUniform") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessTruncatedNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp< \ + DEVICE##Device, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); + +#define REGISTER_INT(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomUniformIntOp<DEVICE##Device, TYPE>); + +#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE) +#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) +#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, 
TYPE) +#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE) + +TF_CALL_half(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_int32(REGISTER_INT_CPU); +TF_CALL_int64(REGISTER_INT_CPU); #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - GPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +TF_CALL_int32(REGISTER_INT_GPU); +TF_CALL_int64(REGISTER_INT_GPU); -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); +#endif // GOOGLE_CUDA #undef REGISTER - -#endif // GOOGLE_CUDA +#undef REGISTER_INT +#undef REGISTER_CPU +#undef REGISTER_GPU +#undef REGISTER_INT_CPU +#undef REGISTER_INT_GPU } // namespace diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 742709fb18..f919a21d60 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -19,42 +19,55 @@ limitations under the License. 
namespace tensorflow { using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShape(shape_inference::InferenceContext* context) { +static Status StatelessShape(InferenceContext* c) { // Check seed shape ShapeHandle seed; - TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); DimensionHandle unused; - TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); // Set output shape ShapeHandle out; - TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out)); - context->set_output(0, out); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); return Status::OK(); } -#define REGISTER_STATELESS_OP(name) \ - REGISTER_OP(name) \ - .Input("shape: T") \ - .Input("seed: Tseed") \ - .Output("output: dtype") \ - .Attr("dtype: {half,float,double} = DT_FLOAT") \ - .Attr("T: {int32, int64} = DT_INT32") \ - .Attr("Tseed: {int32, int64} = DT_INT64") \ +#define REGISTER_STATELESS_OP(name) \ + REGISTER_OP(name) \ + .Input("shape: T") \ + .Input("seed: Tseed") \ + .Output("output: dtype") \ + .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \ + .Attr("T: {int32, int64} = DT_INT32") \ + .Attr("Tseed: {int32, int64} = DT_INT64") \ .SetShapeFn(StatelessShape) -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomUniform"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomNormal"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessTruncatedNormal"); -// This op is exposed through contrib/stateless only. The interface may change. 
+#undef REGISTER_STATELESS_OP + +REGISTER_OP("StatelessRandomUniformInt") + .Input("shape: T") + .Input("seed: Tseed") + .Input("minval: dtype") + .Input("maxval: dtype") + .Output("output: dtype") + .Attr("dtype: {int32, int64}") + .Attr("T: {int32, int64}") + .Attr("Tseed: {int32, int64} = DT_INT64") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return StatelessShape(c); + }); + REGISTER_OP("StatelessMultinomial") .Input("logits: T") .Input("num_samples: int32") @@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial") return Status::OK(); }); -#undef REGISTER_STATELESS_OP - } // namespace tensorflow |