# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for stateless random-number generation ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) with self.cached_session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 for stateless_op in [ stateless.stateless_random_uniform, stateless.stateless_random_normal ]: for shape in (), (3,), (2, 5): for dtype in self._random_types(): pure = stateless_op(shape, seed=seed_t, dtype=dtype) values = [(seed, pure.eval(feed_dict={ seed_t: seed })) for seed in seeds] for s0, v0 in values: for s1, v1 in values: self.assertEqual(s0 == s1, np.all(v0 == v1)) def testRandomUniformIsInRange(self): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( shape=[1000], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(y >= 0)) self.assertTrue(np.all(y < 1)) def _chi_squared(self, x, bins): """Pearson's Chi-squared test.""" x = np.ravel(x) n = len(x) histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) expected = n / float(bins) return np.sum(np.square(histogram - expected) / expected) def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 x = stateless.stateless_random_uniform( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [565656, 121212]}) # Tests that the values are distributed amongst 10 bins with equal # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. This test is probabilistic and would be flaky if the random # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) def testRandomNormalIsFinite(self): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_normal( shape=[10000], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) def _anderson_darling(self, x): """Anderson-Darling test for a standard normal distribution.""" x = np.sort(np.ravel(x)) n = len(x) i = np.linspace(1, n, n) z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) return -n - z / n def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 x = stateless.stateless_random_normal( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [25252, 314159]}) # The constant 2.492 is the 5% critical value for the Anderson-Darling # test where the mean and variance are known. This test is probabilistic # so to avoid flakiness the seed is fixed. self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 x = stateless.stateless_truncated_normal( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) def normal_cdf(x): return .5 * math.erfc(-x / math.sqrt(2)) def normal_pdf(x): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): return sess.run(special_math.ndtri(x)) a = -2. b = 2. mu = 0. sigma = 1. alpha = (a - mu) / sigma beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) self.assertTrue((y >= a).sum() == n) self.assertTrue((y <= b).sum() == n) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma actual_median = np.median(y) self.assertAllClose(actual_median, expected_median, atol=8e-4) expected_variance = sigma**2 * (1 + ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) self.assertAllClose(actual_variance, expected_variance, rtol=1e-3) if __name__ == '__main__': test.main()