diff options
author | 2018-04-11 18:09:42 -0700 | |
---|---|---|
committer | 2018-04-11 18:11:58 -0700 | |
commit | 70d99359fcb9aa9efa955fab06227373c734728b (patch) | |
tree | e5c9d1c6cbed02be0a352f85f64b525c3dddcbe9 /tensorflow/contrib/stateless | |
parent | 1a721ecd9a9992d48c0deb3008b1fc8df297d300 (diff) |
Add `tf.contrib.stateless.stateless_multinomial()`.
This is a starting point for Dataset-compatible weighted sampling across a list of datasets.
PiperOrigin-RevId: 192540412
Diffstat (limited to 'tensorflow/contrib/stateless')
-rw-r--r-- | tensorflow/contrib/stateless/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py | 46 |
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py index ca937546f5..0cca40f071 100644 --- a/tensorflow/contrib/stateless/__init__.py +++ b/tensorflow/contrib/stateless/__init__.py @@ -22,6 +22,7 @@ WARNING: These ops are in contrib, and are not stable. They should be consistent across multiple runs on the same hardware, but only for the same version of the code. +@@stateless_multinomial @@stateless_random_uniform @@stateless_random_normal @@stateless_truncated_normal @@ -37,6 +38,7 @@ from tensorflow.contrib.stateless.gen_stateless_random_ops import * from tensorflow.python.framework import ops from tensorflow.python.util.all_util import remove_undocumented +ops.NotDifferentiable("StatelessMultinomial") ops.NotDifferentiable("StatelessRandomNormal") ops.NotDifferentiable("StatelessRandomUniform") ops.NotDifferentiable("StatelessTruncatedNormal") diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py index bea6341cfd..d724a5c014 100644 --- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py +++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py @@ -96,6 +96,52 @@ class StatelessOpsTest(test.TestCase): for s1, v1 in values: self.assertEqual(s0 == s1, np.all(v0 == v1)) + def testMatchStatefulMultinomial(self): + # Stateless ops should be the same as stateful ops on the first call + # after seed scrambling. + key = 0x3ec8f720, 0x02461e29 + num_samples = 4 + for logits_dtype in np.float16, np.float32, np.float64: + for output_dtype in dtypes.int32, dtypes.int64: + for seed in (7, 17), (11, 5), (2, 3): + preseed = invert_philox(key, + (seed[0], 0, seed[1], 0)).astype(np.uint64) + preseed = preseed[::2] | preseed[1::2] << 32 + random_seed.set_random_seed(seed[0]) + with self.test_session(use_gpu=True): + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + logits_t = constant_op.constant(logits, dtype=logits_dtype) + stateful = random_ops.multinomial( + logits_t, + num_samples, + seed=seed[1], + output_dtype=output_dtype) + pure = stateless.stateless_multinomial( + logits_t, + num_samples, + seed=preseed, + output_dtype=output_dtype) + self.assertAllEqual(stateful.eval(), pure.eval()) + + def testDeterminismMultinomial(self): + # Stateless values should be equal iff the seeds are equal (roughly) + num_samples = 10 + with self.test_session(use_gpu=True): + for seed_type in [dtypes.int32, dtypes.int64]: + seed_t = array_ops.placeholder(seed_type, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + pure = stateless.stateless_multinomial( + logits, num_samples, seed=seed_t) + values = [ + (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds + ] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + if __name__ == '__main__': test.main() |