diff options
author | 2018-06-06 12:39:44 -0700 | |
---|---|---|
committer | 2018-06-06 12:43:30 -0700 | |
commit | 8f2e5f0b4a0221ca1573a40a68077326a32c9bc0 (patch) | |
tree | fff8d5d0285fab2169902ced613879485b692061 /tensorflow/compiler/tests/random_ops_test.py | |
parent | 8b460629e51356485d4da80d81f22e5911a64788 (diff) |
[TF:XLA] Add a implementation of RandomShuffle.
PiperOrigin-RevId: 199511721
Diffstat (limited to 'tensorflow/compiler/tests/random_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 70be22936a..f13dff9620 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest @@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,13 +72,14 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) def testTruncatedNormalIsNotConstant(self): + def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) @@ -94,6 +97,29 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == count) self.assertTrue((y <= 2).sum() == count) + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) + if __name__ == '__main__': googletest.main() |