aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/stateless
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-11 18:09:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 18:11:58 -0700
commit70d99359fcb9aa9efa955fab06227373c734728b (patch)
treee5c9d1c6cbed02be0a352f85f64b525c3dddcbe9 /tensorflow/contrib/stateless
parent1a721ecd9a9992d48c0deb3008b1fc8df297d300 (diff)
Add `tf.contrib.stateless.stateless_multinomial()`.
This is a starting point for Dataset-compatible weighted sampling across a list of datasets. PiperOrigin-RevId: 192540412
Diffstat (limited to 'tensorflow/contrib/stateless')
-rw-r--r--tensorflow/contrib/stateless/__init__.py2
-rw-r--r--tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py46
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index ca937546f5..0cca40f071 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -22,6 +22,7 @@ WARNING: These ops are in contrib, and are not stable. They should be
consistent across multiple runs on the same hardware, but only for the same
version of the code.
+@@stateless_multinomial
@@stateless_random_uniform
@@stateless_random_normal
@@stateless_truncated_normal
@@ -37,6 +38,7 @@ from tensorflow.contrib.stateless.gen_stateless_random_ops import *
from tensorflow.python.framework import ops
from tensorflow.python.util.all_util import remove_undocumented
+ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessTruncatedNormal")
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
index bea6341cfd..d724a5c014 100644
--- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -96,6 +96,52 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
+ def testMatchStatefulMultinomial(self):
+ # Stateless ops should be the same as stateful ops on the first call
+ # after seed scrambling.
+ key = 0x3ec8f720, 0x02461e29
+ num_samples = 4
+ for logits_dtype in np.float16, np.float32, np.float64:
+ for output_dtype in dtypes.int32, dtypes.int64:
+ for seed in (7, 17), (11, 5), (2, 3):
+ preseed = invert_philox(key,
+ (seed[0], 0, seed[1], 0)).astype(np.uint64)
+ preseed = preseed[::2] | preseed[1::2] << 32
+ random_seed.set_random_seed(seed[0])
+ with self.test_session(use_gpu=True):
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ logits_t = constant_op.constant(logits, dtype=logits_dtype)
+ stateful = random_ops.multinomial(
+ logits_t,
+ num_samples,
+ seed=seed[1],
+ output_dtype=output_dtype)
+ pure = stateless.stateless_multinomial(
+ logits_t,
+ num_samples,
+ seed=preseed,
+ output_dtype=output_dtype)
+ self.assertAllEqual(stateful.eval(), pure.eval())
+
+ def testDeterminismMultinomial(self):
+ # Stateless values should be equal iff the seeds are equal (roughly)
+ num_samples = 10
+ with self.test_session(use_gpu=True):
+ for seed_type in [dtypes.int32, dtypes.int64]:
+ seed_t = array_ops.placeholder(seed_type, shape=[2])
+ seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ pure = stateless.stateless_multinomial(
+ logits, num_samples, seed=seed_t)
+ values = [
+ (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
+ ]
+ for s0, v0 in values:
+ for s1, v1 in values:
+ self.assertEqual(s0 == s1, np.all(v0 == v1))
+
if __name__ == '__main__':
test.main()