aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/stateless_random_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/stateless_random_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 4336ebdbd1..b6f8390a45 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase):
# seed were not fixed.
self.assertTrue(self._chi_squared(y, 10) < 16.92)
+ def testRandomNormalIsFinite(self):
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self._random_types():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless.stateless_random_uniform(
+ shape=[10000], seed=seed_t, dtype=dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertTrue(np.all(np.isfinite(y)))
+
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))