diff options
Diffstat (limited to 'tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py')
-rw-r--r-- | tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py | 154 |
1 files changed, 80 insertions, 74 deletions
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py index d724a5c014..ec5a13b7c6 100644 --- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py +++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + import numpy as np from tensorflow.contrib import stateless from tensorflow.python.framework import constant_op @@ -27,10 +29,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform), - (stateless.stateless_random_normal, random_ops.random_normal), - (stateless.stateless_truncated_normal, random_ops.truncated_normal)] - def invert_philox(key, value): """Invert the Philox bijection.""" @@ -51,90 +49,30 @@ def invert_philox(key, value): class StatelessOpsTest(test.TestCase): - def testMatchStateful(self): + def _test_match(self, cases): # Stateless ops should be the same as stateful ops on the first call # after seed scrambling. + cases = tuple(cases) key = 0x3ec8f720, 0x02461e29 for seed in (7, 17), (11, 5), (2, 3): preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64) preseed = preseed[::2] | preseed[1::2] << 32 random_seed.set_random_seed(seed[0]) with self.test_session(use_gpu=True): - for stateless_op, stateful_op in CASES: - for shape in (), (3,), (2, 5): - stateful = stateful_op(shape, seed=seed[1]) - pure = stateless_op(shape, seed=preseed) - self.assertAllEqual(stateful.eval(), pure.eval()) + for stateless_op, stateful_op in cases: + stateful = stateful_op(seed=seed[1]) + pure = stateless_op(seed=preseed) + self.assertAllEqual(stateful.eval(), pure.eval()) - def testDeterminism(self): + def _test_determinism(self, cases): # Stateless values should be equal iff the seeds are equal (roughly) + cases = tuple(cases) with self.test_session(use_gpu=True): for seed_type in [dtypes.int32, dtypes.int64]: seed_t = array_ops.placeholder(seed_type, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 - for stateless_op, _ in CASES: - for shape in (), (3,), (2, 5): - pure = stateless_op(shape, seed=seed_t) - values = [(seed, pure.eval(feed_dict={seed_t: seed})) - for seed in seeds] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) - - def testShapeType(self): - with self.test_session(use_gpu=True): - for shape_dtype in [dtypes.int32, dtypes.int64]: - seed_t = array_ops.placeholder(dtypes.int64, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 - for stateless_op, _ in CASES: - for shape in (), (3,), (2, 5): - pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype), - seed=seed_t) - values = [(seed, pure.eval(feed_dict={seed_t: seed})) - for seed in seeds] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) - - def testMatchStatefulMultinomial(self): - # Stateless ops should be the same as stateful ops on the first call - # after seed scrambling. - key = 0x3ec8f720, 0x02461e29 - num_samples = 4 - for logits_dtype in np.float16, np.float32, np.float64: - for output_dtype in dtypes.int32, dtypes.int64: - for seed in (7, 17), (11, 5), (2, 3): - preseed = invert_philox(key, - (seed[0], 0, seed[1], 0)).astype(np.uint64) - preseed = preseed[::2] | preseed[1::2] << 32 - random_seed.set_random_seed(seed[0]) - with self.test_session(use_gpu=True): - for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], - [0.25, 0.75]]): - logits_t = constant_op.constant(logits, dtype=logits_dtype) - stateful = random_ops.multinomial( - logits_t, - num_samples, - seed=seed[1], - output_dtype=output_dtype) - pure = stateless.stateless_multinomial( - logits_t, - num_samples, - seed=preseed, - output_dtype=output_dtype) - self.assertAllEqual(stateful.eval(), pure.eval()) - - def testDeterminismMultinomial(self): - # Stateless values should be equal iff the seeds are equal (roughly) - num_samples = 10 - with self.test_session(use_gpu=True): - for seed_type in [dtypes.int32, dtypes.int64]: - seed_t = array_ops.placeholder(seed_type, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 - for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], - [0.25, 0.75]]): - pure = stateless.stateless_multinomial( - logits, num_samples, seed=seed_t) + for stateless_op, _ in cases: + pure = stateless_op(seed=seed_t) values = [ (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds ] @@ -142,6 +80,74 @@ class StatelessOpsTest(test.TestCase): for s1, v1 in values: self.assertEqual(s0 == s1, np.all(v0 == v1)) + def _float_cases(self, shape_dtypes=(None,)): + float_cases = ( + # Uniform distribution, with and without range + (stateless.stateless_random_uniform, random_ops.random_uniform, {}), + (stateless.stateless_random_uniform, random_ops.random_uniform, + dict(minval=2.2, maxval=7.1)), + # Normal distribution, with and without mean+stddev + (stateless.stateless_random_normal, random_ops.random_normal, {}), + (stateless.stateless_random_normal, random_ops.random_normal, + dict(mean=2, stddev=3)), + # Truncated normal distribution, with and without mean+stddev + (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}), + (stateless.stateless_truncated_normal, random_ops.truncated_normal, + dict(mean=3, stddev=4)), + ) + for dtype in dtypes.float16, dtypes.float32, dtypes.float64: + for shape_dtype in shape_dtypes: + for shape in (), (3,), (2, 5): + if shape_dtype is not None: + shape = constant_op.constant(shape, dtype=shape_dtype) + for stateless_op, stateful_op, kwds in float_cases: + kwds = dict(shape=shape, dtype=dtype, **kwds) + yield (functools.partial(stateless_op, **kwds), + functools.partial(stateful_op, **kwds)) + + def _int_cases(self, shape_dtypes=(None,)): + for shape_dtype in shape_dtypes: + for shape in (), (3,), (2, 5): + if shape_dtype is not None: + shape = constant_op.constant(shape, dtype=shape_dtype) + for dtype in dtypes.int32, dtypes.int64: + kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape) + yield (functools.partial(stateless.stateless_random_uniform, **kwds), + functools.partial(random_ops.random_uniform, **kwds)) + + def _multinomial_cases(self): + num_samples = 10 + for logits_dtype in np.float16, np.float32, np.float64: + for output_dtype in dtypes.int32, dtypes.int64: + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + kwds = dict( + logits=constant_op.constant(logits, dtype=logits_dtype), + num_samples=num_samples, + output_dtype=output_dtype) + yield (functools.partial(stateless.stateless_multinomial, **kwds), + functools.partial(random_ops.multinomial, **kwds)) + + def testMatchFloat(self): + self._test_match(self._float_cases()) + + def testMatchInt(self): + self._test_match(self._int_cases()) + + def testMatchMultinomial(self): + self._test_match(self._multinomial_cases()) + + def testDeterminismFloat(self): + self._test_determinism( + self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) + + def testDeterminismInt(self): + self._test_determinism( + self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) + + def testDeterminismMultinomial(self): + self._test_determinism(self._multinomial_cases()) + if __name__ == '__main__': test.main() |