author     TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 18:02:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 18:02:42 -0700
commit     5dcca3baca11de0687747e9b5ad8854b77fd097d (patch)
tree       4eb40a90582b78285963bff749953dafd2feed03 /tensorflow/contrib
parent     213d76a6ed77a696883502c53a3a4f81d2ee4042 (diff)
parent     1e104d80826fed95f9fad6f07f68e35cae3527b2 (diff)
Merge pull request #22386 from girving:stateless
PiperOrigin-RevId: 215995215
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/stateless/BUILD                                             |   5
-rw-r--r--  tensorflow/contrib/stateless/__init__.py                                       |   9
-rw-r--r--  tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py  | 154
-rw-r--r--  tensorflow/contrib/stateless/python/stateless_ops.py                           | 214
4 files changed, 299 insertions, 83 deletions
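The net effect of this merge is that the thin wildcard re-export of gen_stateless_random_ops in __init__.py is replaced by documented wrappers in python/stateless_ops.py that accept minval/maxval, mean/stddev, and output_dtype, mirroring the stateful tf.random_* APIs. A minimal usage sketch follows; it assumes a TensorFlow 1.x build that includes this change, and the shapes, seed, and distribution parameters are illustrative (they mirror values used in the tests below):

```python
import tensorflow as tf
from tensorflow.contrib import stateless

# Same seed -> same values, across runs and across CPU/GPU.
u = stateless.stateless_random_uniform(
    shape=[2, 3], seed=[7, 17], minval=2.2, maxval=7.1)
n = stateless.stateless_random_normal(
    shape=[2, 3], seed=[7, 17], mean=2.0, stddev=3.0)
with tf.Session() as sess:
  print(sess.run([u, n]))
```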
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index a217397c1a..e9ddec8889 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -11,7 +11,10 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
py_library(
name = "stateless",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "python/stateless_ops.py",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index fe23fe0dd8..30d0a7ab6a 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -32,16 +32,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-
# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_stateless_random_ops import *
+from tensorflow.contrib.stateless.python.stateless_ops import *
from tensorflow.python.util.all_util import remove_undocumented
-ops.NotDifferentiable("StatelessMultinomial")
-ops.NotDifferentiable("StatelessRandomNormal")
-ops.NotDifferentiable("StatelessRandomUniform")
-ops.NotDifferentiable("StatelessTruncatedNormal")
-
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
index d724a5c014..ec5a13b7c6 100644
--- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
import numpy as np
from tensorflow.contrib import stateless
from tensorflow.python.framework import constant_op
@@ -27,10 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform),
- (stateless.stateless_random_normal, random_ops.random_normal),
- (stateless.stateless_truncated_normal, random_ops.truncated_normal)]
-
def invert_philox(key, value):
"""Invert the Philox bijection."""
@@ -51,90 +49,30 @@ def invert_philox(key, value):
class StatelessOpsTest(test.TestCase):
- def testMatchStateful(self):
+ def _test_match(self, cases):
# Stateless ops should be the same as stateful ops on the first call
# after seed scrambling.
+ cases = tuple(cases)
key = 0x3ec8f720, 0x02461e29
for seed in (7, 17), (11, 5), (2, 3):
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
random_seed.set_random_seed(seed[0])
with self.test_session(use_gpu=True):
- for stateless_op, stateful_op in CASES:
- for shape in (), (3,), (2, 5):
- stateful = stateful_op(shape, seed=seed[1])
- pure = stateless_op(shape, seed=preseed)
- self.assertAllEqual(stateful.eval(), pure.eval())
+ for stateless_op, stateful_op in cases:
+ stateful = stateful_op(seed=seed[1])
+ pure = stateless_op(seed=preseed)
+ self.assertAllEqual(stateful.eval(), pure.eval())
- def testDeterminism(self):
+ def _test_determinism(self, cases):
# Stateless values should be equal iff the seeds are equal (roughly)
+ cases = tuple(cases)
with self.test_session(use_gpu=True):
for seed_type in [dtypes.int32, dtypes.int64]:
seed_t = array_ops.placeholder(seed_type, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(shape, seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testShapeType(self):
- with self.test_session(use_gpu=True):
- for shape_dtype in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype),
- seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testMatchStatefulMultinomial(self):
- # Stateless ops should be the same as stateful ops on the first call
- # after seed scrambling.
- key = 0x3ec8f720, 0x02461e29
- num_samples = 4
- for logits_dtype in np.float16, np.float32, np.float64:
- for output_dtype in dtypes.int32, dtypes.int64:
- for seed in (7, 17), (11, 5), (2, 3):
- preseed = invert_philox(key,
- (seed[0], 0, seed[1], 0)).astype(np.uint64)
- preseed = preseed[::2] | preseed[1::2] << 32
- random_seed.set_random_seed(seed[0])
- with self.test_session(use_gpu=True):
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- logits_t = constant_op.constant(logits, dtype=logits_dtype)
- stateful = random_ops.multinomial(
- logits_t,
- num_samples,
- seed=seed[1],
- output_dtype=output_dtype)
- pure = stateless.stateless_multinomial(
- logits_t,
- num_samples,
- seed=preseed,
- output_dtype=output_dtype)
- self.assertAllEqual(stateful.eval(), pure.eval())
-
- def testDeterminismMultinomial(self):
- # Stateless values should be equal iff the seeds are equal (roughly)
- num_samples = 10
- with self.test_session(use_gpu=True):
- for seed_type in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(seed_type, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- pure = stateless.stateless_multinomial(
- logits, num_samples, seed=seed_t)
+ for stateless_op, _ in cases:
+ pure = stateless_op(seed=seed_t)
values = [
(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
]
@@ -142,6 +80,74 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
+ def _float_cases(self, shape_dtypes=(None,)):
+ float_cases = (
+ # Uniform distribution, with and without range
+ (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
+ (stateless.stateless_random_uniform, random_ops.random_uniform,
+ dict(minval=2.2, maxval=7.1)),
+ # Normal distribution, with and without mean+stddev
+ (stateless.stateless_random_normal, random_ops.random_normal, {}),
+ (stateless.stateless_random_normal, random_ops.random_normal,
+ dict(mean=2, stddev=3)),
+ # Truncated normal distribution, with and without mean+stddev
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal,
+ dict(mean=3, stddev=4)),
+ )
+ for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for stateless_op, stateful_op, kwds in float_cases:
+ kwds = dict(shape=shape, dtype=dtype, **kwds)
+ yield (functools.partial(stateless_op, **kwds),
+ functools.partial(stateful_op, **kwds))
+
+ def _int_cases(self, shape_dtypes=(None,)):
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for dtype in dtypes.int32, dtypes.int64:
+ kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape)
+ yield (functools.partial(stateless.stateless_random_uniform, **kwds),
+ functools.partial(random_ops.random_uniform, **kwds))
+
+ def _multinomial_cases(self):
+ num_samples = 10
+ for logits_dtype in np.float16, np.float32, np.float64:
+ for output_dtype in dtypes.int32, dtypes.int64:
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ kwds = dict(
+ logits=constant_op.constant(logits, dtype=logits_dtype),
+ num_samples=num_samples,
+ output_dtype=output_dtype)
+ yield (functools.partial(stateless.stateless_multinomial, **kwds),
+ functools.partial(random_ops.multinomial, **kwds))
+
+ def testMatchFloat(self):
+ self._test_match(self._float_cases())
+
+ def testMatchInt(self):
+ self._test_match(self._int_cases())
+
+ def testMatchMultinomial(self):
+ self._test_match(self._multinomial_cases())
+
+ def testDeterminismFloat(self):
+ self._test_determinism(
+ self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismInt(self):
+ self._test_determinism(
+ self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismMultinomial(self):
+ self._test_determinism(self._multinomial_cases())
+
if __name__ == '__main__':
test.main()
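The refactored tests above bind op arguments with functools.partial and then check one property: stateless values are equal iff the seeds are equal. A standalone sketch of that property, outside the test harness, assuming a TF 1.x session environment with this contrib module available (the seed values are arbitrary):

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib import stateless

seed_t = tf.placeholder(tf.int64, shape=[2])
sample = stateless.stateless_random_uniform(shape=[3], seed=seed_t)
with tf.Session() as sess:
  a = sess.run(sample, feed_dict={seed_t: [1, 2]})
  b = sess.run(sample, feed_dict={seed_t: [1, 2]})
  c = sess.run(sample, feed_dict={seed_t: [3, 4]})
  assert np.array_equal(a, b)      # identical seeds give identical draws
  assert not np.array_equal(a, c)  # distinct seeds (almost surely) differ
```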
diff --git a/tensorflow/contrib/stateless/python/stateless_ops.py b/tensorflow/contrib/stateless/python/stateless_ops.py
new file mode 100644
index 0000000000..1449825c83
--- /dev/null
+++ b/tensorflow/contrib/stateless/python/stateless_ops.py
@@ -0,0 +1,214 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless random ops which take seed as a tensor input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_stateless_random_ops
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import math_ops
+
+ops.NotDifferentiable("StatelessMultinomial")
+ops.NotDifferentiable("StatelessRandomNormal")
+ops.NotDifferentiable("StatelessRandomUniform")
+ops.NotDifferentiable("StatelessRandomUniformInt")
+ops.NotDifferentiable("StatelessTruncatedNormal")
+
+
+def stateless_random_uniform(shape,
+ seed,
+ minval=0,
+ maxval=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a uniform distribution.
+
+ This is a stateless version of `tf.random_uniform`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
+ be specified explicitly.
+
+ In the integer case, the random integers are slightly biased unless
+ `maxval - minval` is an exact power of two. The bias is small for values of
+ `maxval - minval` significantly smaller than the range of the output (either
+ `2**32` or `2**64`).
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate. Defaults to 0.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
+ range of random values to generate. Defaults to 1 if `dtype` is floating
+ point.
+ dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
+ `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
+
+ Raises:
+ ValueError: If `dtype` is integral and `maxval` is not specified.
+ """
+ dtype = dtypes.as_dtype(dtype)
+ if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
+ dtypes.float64, dtypes.int32, dtypes.int64):
+ raise ValueError("Invalid dtype %r" % dtype)
+ if maxval is None:
+ if dtype.is_integer:
+ raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+ maxval = 1
+ with ops.name_scope(name, "stateless_random_uniform",
+ [shape, seed, minval, maxval]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+ if dtype.is_integer:
+ return gen_stateless_random_ops.stateless_random_uniform_int(
+ shape, seed=seed, minval=minval, maxval=maxval, name=name)
+ else:
+ rnd = gen_stateless_random_ops.stateless_random_uniform(
+ shape, seed=seed, dtype=dtype)
+ return math_ops.add(rnd * (maxval - minval), minval, name=name)
+
+
+def stateless_random_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a normal distribution.
+
+ This is a stateless version of `tf.random_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
+ """
+ with ops.name_scope(name, "stateless_random_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_truncated_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values, truncated normally distributed.
+
+ This is a stateless version of `tf.truncated_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution, before truncation.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.name_scope(name, "stateless_truncated_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_truncated_normal(
+ shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_multinomial(logits,
+ num_samples,
+ seed,
+ output_dtype=dtypes.int64,
+ name=None):
+ """Draws deterministic pseudorandom samples from a multinomial distribution.
+
+ This is a stateless version of `tf.multinomial`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Example:
+
+ ```python
+ # samples has shape [1, 5], where each value is either 0 or 1 with equal
+ # probability.
+ samples = tf.contrib.stateless.stateless_multinomial(
+ tf.log([[10., 10.]]), 5, seed=[7, 17])
+ ```
+
+ Args:
+ logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice
+ `[i, :]` represents the unnormalized log-probabilities for all classes.
+ num_samples: 0-D. Number of independent samples to draw for each row slice.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ name: Optional name for the operation.
+ output_dtype: integer type to use for the output. Defaults to int64.
+
+ Returns:
+ The drawn samples of shape `[batch_size, num_samples]`.
+ """
+ with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
+ logits = ops.convert_to_tensor(logits, name="logits")
+ return gen_stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed, output_dtype=output_dtype)
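One behavior introduced by the new wrapper is the integer path of stateless_random_uniform: with an integer dtype, maxval must be given explicitly and the call dispatches to StatelessRandomUniformInt rather than scaling a float sample. A sketch under the assumption of a TF 1.x build containing this module; the bounds mirror those used in _int_cases above:

```python
import tensorflow as tf
from tensorflow.contrib import stateless

# Integer dtypes require an explicit maxval.
ints = stateless.stateless_random_uniform(
    shape=[4], seed=[7, 17], minval=2, maxval=11111, dtype=tf.int32)
with tf.Session() as sess:
  print(sess.run(ints))  # four ints in [2, 11111), fixed by seed [7, 17]
```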