aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/random_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-06 12:39:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-06 12:43:30 -0700
commit8f2e5f0b4a0221ca1573a40a68077326a32c9bc0 (patch)
treefff8d5d0285fab2169902ced613879485b692061 /tensorflow/compiler/tests/random_ops_test.py
parent8b460629e51356485d4da80d81f22e5911a64788 (diff)
[TF:XLA] Add a implementation of RandomShuffle.
PiperOrigin-RevId: 199511721
Diffstat (limited to 'tensorflow/compiler/tests/random_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py38
1 files changed, 32 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 70be22936a..f13dff9620 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest
@@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase):
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
- (not np.array_equal(z, w)) or
- (not np.array_equal(y, w)))
+ (not np.array_equal(z, w)) or (not np.array_equal(y, w)))
def testRandomUniformIsNotConstant(self):
+
def rng(dtype):
- return random_ops.random_uniform(shape=[2], dtype=dtype,
- maxval=1000000)
+ return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
@@ -70,13 +72,14 @@ class RandomOpsTest(XLATestCase):
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
- x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
- maxval=33)
+ x = random_ops.random_uniform(
+ shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
def testTruncatedNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
@@ -94,6 +97,29 @@ class RandomOpsTest(XLATestCase):
self.assertTrue((y >= -2).sum() == count)
self.assertTrue((y <= 2).sum() == count)
+ def testShuffle1d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = math_ops.range(20)
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = range(20)
+ # Compare sets to avoid randomness behavior changes but make sure still
+ # have all the values.
+ self.assertAllEqual(set(result), set(expected))
+
+ def testShuffle2d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = array_ops.diag(math_ops.range(20))
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = np.diag(range(20)).flatten()
+ # Compare sets to avoid randomness behavior changes but make sure still
+ # have all the values.
+ self.assertAllEqual(len(result.flatten()), len(expected))
+ self.assertAllEqual(set(result.flatten()), set(expected))
+
if __name__ == '__main__':
googletest.main()