aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/distributions/student_t_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/distributions/student_t_test.py')
-rw-r--r--tensorflow/python/kernel_tests/distributions/student_t_test.py516
1 files changed, 516 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
new file mode 100644
index 0000000000..f1150de58e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -0,0 +1,516 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Student t distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import math
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.distributions import student_t
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def try_import(name): # pylint: disable=invalid-name
+ module = None
+ try:
+ module = importlib.import_module(name)
+ except ImportError as e:
+ tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+ return module
+
+
+stats = try_import("scipy.stats")
+
+
+class StudentTTest(test.TestCase):
+
+ def testStudentPDFAndLogPDF(self):
+ with self.test_session():
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=-sigma)
+
+ log_pdf = student.log_prob(t)
+ self.assertEquals(log_pdf.get_shape(), (6,))
+ log_pdf_values = log_pdf.eval()
+ pdf = student.prob(t)
+ self.assertEquals(pdf.get_shape(), (6,))
+ pdf_values = pdf.eval()
+
+ if not stats:
+ return
+
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+
+ def testStudentLogPDFMultidimensional(self):
+ with self.test_session():
+ batch_size = 6
+ df = constant_op.constant([[1.5, 7.2]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
+ batch_size)
+ df_v = np.array([1.5, 7.2])
+ mu_v = np.array([3., -3.])
+ sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
+ t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+ log_pdf = student.log_prob(t)
+ log_pdf_values = log_pdf.eval()
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = student.prob(t)
+ pdf_values = pdf.eval()
+ self.assertEqual(pdf.get_shape(), (6, 2))
+
+ if not stats:
+ return
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+
+ def testStudentCDFAndLogCDF(self):
+ with self.test_session():
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([-8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+
+ log_cdf = student.log_cdf(t)
+ self.assertEquals(log_cdf.get_shape(), (6,))
+ log_cdf_values = log_cdf.eval()
+ cdf = student.cdf(t)
+ self.assertEquals(cdf.get_shape(), (6,))
+ cdf_values = cdf.eval()
+
+ if not stats:
+ return
+ expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
+
+ def testStudentEntropy(self):
+ df_v = np.array([[2., 3., 7.]]) # 1x3
+ mu_v = np.array([[1., -1, 0]]) # 1x3
+ sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1
+ with self.test_session():
+ student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+ ent = student.entropy()
+ ent_values = ent.eval()
+
+ # Help scipy broadcast to 3x3
+ ones = np.array([[1, 1, 1]])
+ sigma_bc = np.abs(sigma_v) * ones
+ mu_bc = ones.T * mu_v
+ df_bc = ones.T * df_v
+ if not stats:
+ return
+ expected_entropy = stats.t.entropy(
+ np.reshape(df_bc, [-1]),
+ loc=np.reshape(mu_bc, [-1]),
+ scale=np.reshape(sigma_bc, [-1]))
+ expected_entropy = np.reshape(expected_entropy, df_bc.shape)
+ self.assertAllClose(expected_entropy, ent_values)
+
+ def testStudentSample(self):
+ with self.test_session():
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(-math.sqrt(10.))
+ df_v = 4.
+ mu_v = 3.
+ sigma_v = np.sqrt(10.)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = samples.eval()
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val,))
+ self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0)
+ self.assertAllClose(
+ sample_values.var(),
+ sigma_v**2 * df_v / (df_v - 2),
+ rtol=1e-2,
+ atol=0)
+ self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+
+ # Test that sampling with the same seed twice gives the same results.
+ def testStudentSampleMultipleTimes(self):
+ with self.test_session():
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(math.sqrt(10.))
+ n = constant_op.constant(100)
+
+ random_seed.set_random_seed(654321)
+ student = student_t.StudentT(
+ df=df, loc=mu, scale=sigma, name="student_t1")
+ samples1 = student.sample(n, seed=123456).eval()
+
+ random_seed.set_random_seed(654321)
+ student2 = student_t.StudentT(
+ df=df, loc=mu, scale=sigma, name="student_t2")
+ samples2 = student2.sample(n, seed=123456).eval()
+
+ self.assertAllClose(samples1, samples2)
+
+ def testStudentSampleSmallDfNoNan(self):
+ with self.test_session():
+ df_v = [1e-1, 1e-5, 1e-10, 1e-20]
+ df = constant_op.constant(df_v)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=1., scale=1.)
+ samples = student.sample(n, seed=123456)
+ sample_values = samples.eval()
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val, 4))
+ self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
+
+ def testStudentSampleMultiDimensional(self):
+ with self.test_session():
+ batch_size = 7
+ df = constant_op.constant([[3., 7.]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
+ batch_size)
+ df_v = [3., 7.]
+ mu_v = [3., -3.]
+ sigma_v = [np.sqrt(10.), np.sqrt(15.)]
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = samples.eval()
+ self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
+ self.assertAllClose(
+ sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 0].var(),
+ sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+ rtol=1e-1,
+ atol=0)
+ self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+ self.assertAllClose(
+ sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 1].var(),
+ sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+ rtol=1e-1,
+ atol=0)
+ self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1])
+
+ def _checkKLApprox(self, df, mu, sigma, samples):
+ n = samples.size
+ np.random.seed(137)
+ if not stats:
+ return
+ sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
+ covg = 0.99
+ r = stats.t.interval(covg, df, loc=mu, scale=sigma)
+ bins = 100
+ hist, _ = np.histogram(samples, bins=bins, range=r)
+ hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
+ self.assertGreater(hist.sum(), n * (covg - .01))
+ self.assertGreater(hist_scipy.sum(), n * (covg - .01))
+ hist_min1 = hist + 1. # put at least one item in each bucket
+ hist_norm = hist_min1 / hist_min1.sum()
+ hist_scipy_min1 = hist_scipy + 1. # put at least one item in each bucket
+ hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
+ kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
+ self.assertLess(kl_appx, 1)
+
+ def testBroadcastingParams(self):
+
+ def _check(student):
+ self.assertEqual(student.mean().get_shape(), (3,))
+ self.assertEqual(student.variance().get_shape(), (3,))
+ self.assertEqual(student.entropy().get_shape(), (3,))
+ self.assertEqual(student.log_prob(2.).get_shape(), (3,))
+ self.assertEqual(student.prob(2.).get_shape(), (3,))
+ self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
+
+ _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+ _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+ _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
+
+ def testBroadcastingPdfArgs(self):
+
+ def _assert_shape(student, arg, shape):
+ self.assertEqual(student.log_prob(arg).get_shape(), shape)
+ self.assertEqual(student.prob(arg).get_shape(), shape)
+
+ def _check(student):
+ _assert_shape(student, 2., (3,))
+ xs = np.array([2., 3., 4.], dtype=np.float32)
+ _assert_shape(student, xs, (3,))
+ xs = np.array([xs])
+ _assert_shape(student, xs, (1, 3))
+ xs = xs.T
+ _assert_shape(student, xs, (3, 3))
+
+ _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
+ _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
+ _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
+
+ def _check2d(student):
+ _assert_shape(student, 2., (1, 3))
+ xs = np.array([2., 3., 4.], dtype=np.float32)
+ _assert_shape(student, xs, (1, 3))
+ xs = np.array([xs])
+ _assert_shape(student, xs, (1, 3))
+ xs = xs.T
+ _assert_shape(student, xs, (3, 3))
+
+ _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
+ _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
+ _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))
+
+ def _check2d_rows(student):
+ _assert_shape(student, 2., (3, 1))
+ xs = np.array([2., 3., 4.], dtype=np.float32) # (3,)
+ _assert_shape(student, xs, (3, 3))
+ xs = np.array([xs]) # (1,3)
+ _assert_shape(student, xs, (3, 3))
+ xs = xs.T # (3,1)
+ _assert_shape(student, xs, (3, 1))
+
+ _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
+ _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
+ _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
+
+ def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
+ with self.test_session():
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+ mean = student.mean().eval()
+ self.assertAllClose([1., 3.3, 4.4], mean)
+
+ def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
+ with self.test_session():
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(
+ df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
+ allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ student.mean().eval()
+
+ def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
+ with self.test_session():
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
+ allow_nan_stats=True)
+ mean = student.mean().eval()
+ self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
+
+ def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
+ with self.test_session():
+ # df = 0.5 ==> undefined mean ==> undefined variance.
+ # df = 1.5 ==> infinite variance.
+ df = [0.5, 1.5, 3., 5., 7.]
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=df, loc=mu, scale=sigma, allow_nan_stats=True)
+ var = student.variance().eval()
+ ## scipy uses inf for variance when the mean is undefined. When mean is
+ # undefined we say variance is undefined as well. So test the first
+ # member of var, making sure it is NaN, then replace with inf and compare
+ # to scipy.
+ self.assertTrue(np.isnan(var[0]))
+ var[0] = np.inf
+
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
+
+ def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
+ self):
+ with self.test_session():
+ # df = 1.5 ==> infinite variance.
+ df = [1.5, 3., 5., 7.]
+ mu = [0., 1., 3.3, 4.4]
+ sigma = [4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ var = student.variance().eval()
+
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
+
+ def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
+ with self.test_session():
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(
+ df=1., loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ student.variance().eval()
+
+ with self.test_session():
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(
+ df=0.5, loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ student.variance().eval()
+
+ def testStd(self):
+ with self.test_session():
+ # Defined for all batch members.
+ df = [3.5, 5., 3., 5., 7.]
+ mu = [-2.2]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ stddev = student.stddev().eval()
+ mu *= len(df)
+
+ if not stats:
+ return
+ expected_stddev = [
+ stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_stddev, stddev)
+
+ def testMode(self):
+ with self.test_session():
+ df = [0.5, 1., 3]
+ mu = [-1, 0., 1]
+ sigma = [5., 4., 3.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ mode = student.mode().eval()
+ self.assertAllClose([-1., 0, 1], mode)
+
+ def testPdfOfSample(self):
+ with self.test_session() as sess:
+ student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
+ num = 20000
+ samples = student.sample(num, seed=123456)
+ pdfs = student.prob(samples)
+ mean = student.mean()
+ mean_pdf = student.prob(student.mean())
+ sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
+ [samples, pdfs, student.mean(), mean_pdf])
+ self.assertEqual(samples.get_shape(), (num,))
+ self.assertEqual(pdfs.get_shape(), (num,))
+ self.assertEqual(mean.get_shape(), ())
+ self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
+ self.assertNear(np.pi, mean_val, err=1e-6)
+ # Verify integral over sample*pdf ~= 1.
+ self._assertIntegral(sample_vals, pdf_vals, err=2e-3)
+ if not stats:
+ return
+ self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
+
+ def testPdfOfSampleMultiDims(self):
+ with self.test_session() as sess:
+ student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
+ self.assertAllEqual([], student.event_shape)
+ self.assertAllEqual([], student.event_shape_tensor().eval())
+ self.assertAllEqual([2, 2], student.batch_shape)
+ self.assertAllEqual([2, 2], student.batch_shape_tensor().eval())
+ num = 50000
+ samples = student.sample(num, seed=123456)
+ pdfs = student.prob(samples)
+ sample_vals, pdf_vals = sess.run([samples, pdfs])
+ self.assertEqual(samples.get_shape(), (num, 2, 2))
+ self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+ self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
+ self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
+ self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+ if not stats:
+ return
+ self.assertNear(
+ stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var
+ np.var(sample_vals[:, :, 0]),
+ err=.4)
+ self.assertNear(
+ stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var
+ np.var(sample_vals[:, :, 1]),
+ err=.4)
+
+ def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
+ s_p = zip(sample_vals, pdf_vals)
+ prev = (sample_vals.min() - 1000, 0)
+ total = 0
+ for k in sorted(s_p, key=lambda x: x[0]):
+ pair_pdf = (k[1] + prev[1]) / 2
+ total += (k[0] - prev[0]) * pair_pdf
+ prev = k
+ self.assertNear(1., total, err=err)
+
+ def testNegativeDofFails(self):
+ with self.test_session():
+ student = student_t.StudentT(df=[2, -5.], loc=0., scale=1.,
+ validate_args=True, name="S")
+ with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
+ student.mean().eval()
+
+ def testStudentTWithAbsDfSoftplusScale(self):
+ with self.test_session():
+ df = constant_op.constant([-3.2, -4.6])
+ mu = constant_op.constant([-4.2, 3.4])
+ sigma = constant_op.constant([-6.4, -8.8])
+ student = student_t.StudentTWithAbsDfSoftplusScale(
+ df=df, loc=mu, scale=sigma)
+ self.assertAllClose(
+ math_ops.floor(math_ops.abs(df)).eval(), student.df.eval())
+ self.assertAllClose(mu.eval(), student.loc.eval())
+ self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval())
+
+
+if __name__ == "__main__":
+ test.main()