author     2017-01-31 17:18:37 -0800
committer  2017-01-31 17:39:00 -0800
commit     0e9cebfc8c35781127ff91246598b87ca8ce0aa5 (patch)
tree       189d62189e802b48e9950852aa099233f7b514cf
parent     bbadfff5834f16d6705b09bce26c2b0972c5dc70 (diff)
BREAKING CHANGE: Standardize "concentration", "rate", "total_count" distribution arguments.
BUGFIX: Correct undefined mode in Dirichlet.mode.
BUGFIX: Correct broadcasting in DirichletMultinomial.mean.
Change: 146187147
20 files changed, 1339 insertions, 1007 deletions
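
The breaking rename is mechanical: each distribution's shape/rate keywords move to the standardized names, and the softplus wrapper classes are renamed to match. Below is a minimal before/after sketch, assuming these classes are all exported from `tf.contrib.distributions` as the `__init__.py` hunk and the tests' imports suggest; the old keywords (shown in comments) no longer work after this change.

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Beta: a/b -> concentration1/concentration0
# (BetaWithSoftplusAB -> BetaWithSoftplusConcentration).
beta = ds.Beta(concentration1=1.5, concentration0=2.0)  # was Beta(a=1.5, b=2.0)

# Exponential: lam -> rate
# (ExponentialWithSoftplusLam -> ExponentialWithSoftplusRate).
expo = ds.Exponential(rate=2.0)  # was Exponential(lam=2.0)

# Poisson: lam -> rate.
poisson = ds.Poisson(rate=4.0)  # was Poisson(lam=4.0)

# Gamma / InverseGamma: alpha/beta -> concentration/rate
# (GammaWithSoftplusAlphaBeta -> GammaWithSoftplusConcentrationRate, and
# likewise for InverseGamma).
gamma = ds.Gamma(concentration=3.0, rate=2.0)  # was Gamma(alpha=3.0, beta=2.0)

# Dirichlet: alpha -> concentration.
diri = ds.Dirichlet(concentration=[1., 2., 3.])  # was Dirichlet(alpha=...)

# DirichletMultinomial: n/alpha -> total_count/concentration. The two
# arguments still broadcast against each other, and mean() carries the
# broadcast batch shape plus the event dimension; that broadcasting is
# what the second BUGFIX corrects.
dm = ds.DirichletMultinomial(
    total_count=5.,
    concentration=[[1., 2., 3.],
                   [4., 5., 6.]])  # was DirichletMultinomial(n=..., alpha=...)

with tf.Session() as sess:
    print(sess.run(beta.mean()))       # 1.5 / (1.5 + 2.0) ~= 0.4286
    print(sess.run(dm.mean()).shape)   # (2, 3)
```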
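
The Dirichlet mode fix is visible in the updated test expectations in the `dirichlet_test.py` hunks below: the mode `(alpha - 1) / (sum(alpha) - K)` is only defined when every concentration exceeds 1, so with `allow_nan_stats=True` the result is now NaN for all components rather than a partially finite vector. A short sketch of the new behavior, based on `testModeEnableAllowNanStats` and `testModeInvalid` in this diff:

```python
import numpy as np
import tensorflow as tf

# First component is exactly 1, so the mode is undefined for the whole event.
alpha = np.array([1., 2., 3.])
d = tf.contrib.distributions.Dirichlet(concentration=alpha,
                                       allow_nan_stats=True)
with tf.Session() as sess:
    # After this change: [nan, nan, nan].
    # The old test expected [nan, 1/3, 2/3], i.e. (alpha - 1) / (sum(alpha) - 3)
    # with only the offending component replaced by NaN.
    print(sess.run(d.mode()))

# With allow_nan_stats=False, evaluating mode() instead raises an op error
# (the test matches "Condition x < y"), per testModeInvalid below.
```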
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 00946a1f52..9ba3a60044 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -30,16 +30,16 @@ initialized with parameters that define the distributions. @@Bernoulli @@BernoulliWithSigmoidProbs @@Beta -@@BetaWithSoftplusAB +@@BetaWithSoftplusConcentration @@Categorical @@Chi2 @@Chi2WithAbsDf @@Exponential -@@ExponentialWithSoftplusLam +@@ExponentialWithSoftplusRate @@Gamma -@@GammaWithSoftplusAlphaBeta +@@GammaWithSoftplusConcentrationRate @@InverseGamma -@@InverseGammaWithSoftplusAlphaBeta +@@InverseGammaWithSoftplusConcentrationRate @@Laplace @@LaplaceWithSoftplusScale @@Normal diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py index 17c3320d70..f524986cec 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py @@ -69,16 +69,16 @@ class BetaTest(test.TestCase): b = [[2., 4, 3]] with self.test_session(): dist = beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.a.get_shape()) - self.assertAllClose(a, dist.a.eval()) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, dist.concentration1.eval()) def testBetaProperty(self): a = [[1., 2, 3]] b = [[2., 4, 3]] with self.test_session(): dist = beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.b.get_shape()) - self.assertAllClose(b, dist.b.eval()) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, dist.concentration0.eval()) def testPdfXProper(self): a = [[1., 2, 3]] @@ -88,11 +88,11 @@ class BetaTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("(Condition x > 0.*|Condition x < y.*)"): - dist.prob([-1., 1, 1]).eval() - with self.assertRaisesOpError("Condition x.*"): - dist.prob([0., 1, 1]).eval() - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("sample must be positive"): + dist.prob([-1., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be positive"): + dist.prob([0., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() def testPdfTwoBatches(self): @@ -247,8 +247,7 @@ class BetaTest(test.TestCase): stats.kstest( # Beta is a univariate distribution. sample_values, - stats.beta( - a=1., b=2.).cdf)[0], + stats.beta(a=1., b=2.).cdf)[0], 0.01) # The standard error of the sample mean is 1 / (sqrt(18 * n)) self.assertAllClose( @@ -264,11 +263,15 @@ class BetaTest(test.TestCase): n_val = 100 random_seed.set_random_seed(654321) - beta1 = beta_lib.Beta(a=a_val, b=b_val, name="beta1") + beta1 = beta_lib.Beta(concentration1=a_val, + concentration0=b_val, + name="beta1") samples1 = beta1.sample(n_val, seed=123456).eval() random_seed.set_random_seed(654321) - beta2 = beta_lib.Beta(a=a_val, b=b_val, name="beta2") + beta2 = beta_lib.Beta(concentration1=a_val, + concentration0=b_val, + name="beta2") samples2 = beta2.sample(n_val, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -312,12 +315,12 @@ class BetaTest(test.TestCase): self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. 
>= x) self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) - def testBetaWithSoftplusAB(self): + def testBetaWithSoftplusConcentration(self): with self.test_session(): a, b = -4.2, -9.1 - dist = beta_lib.BetaWithSoftplusAB(a, b) - self.assertAllClose(nn_ops.softplus(a).eval(), dist.a.eval()) - self.assertAllClose(nn_ops.softplus(b).eval(), dist.b.eval()) + dist = beta_lib.BetaWithSoftplusConcentration(a, b) + self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval()) + self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval()) def testBetaBetaKL(self): with self.test_session() as sess: @@ -326,16 +329,18 @@ class BetaTest(test.TestCase): b1 = 6.0 * np.random.random(size=shape) + 1e-4 a2 = 6.0 * np.random.random(size=shape) + 1e-4 b2 = 6.0 * np.random.random(size=shape) + 1e-4 - # Take inverse softplus of values to test BetaWithSoftplusAB + # Take inverse softplus of values to test BetaWithSoftplusConcentration a1_sp = np.log(np.exp(a1) - 1.0) b1_sp = np.log(np.exp(b1) - 1.0) a2_sp = np.log(np.exp(a2) - 1.0) b2_sp = np.log(np.exp(b2) - 1.0) - d1 = beta_lib.Beta(a=a1, b=b1) - d2 = beta_lib.Beta(a=a2, b=b2) - d1_sp = beta_lib.BetaWithSoftplusAB(a=a1_sp, b=b1_sp) - d2_sp = beta_lib.BetaWithSoftplusAB(a=a2_sp, b=b2_sp) + d1 = beta_lib.Beta(concentration1=a1, concentration0=b1) + d2 = beta_lib.Beta(concentration1=a2, concentration0=b2) + d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp, + concentration0=b1_sp) + d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp, + concentration0=b2_sp) kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) + (a1 - a2) * special.digamma(a1) + diff --git a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py index 4157f0af25..75d48791ec 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py @@ -88,7 +88,8 @@ class Chi2Test(test.TestCase): df_v = np.array([-1.3, -3.2, 5], dtype=np.float64) chi2 = chi2_lib.Chi2WithAbsDf(df=df_v) self.assertAllClose( - math_ops.floor(math_ops.abs(df_v)).eval(), chi2.df.eval()) + math_ops.floor(math_ops.abs(df_v)).eval(), + chi2.df.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 32230c8a06..235ce20945 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -55,15 +55,15 @@ class DirichletMultinomialTest(test.TestCase): n = [[5.]] with self.test_session(): dist = ds.DirichletMultinomial(n, alpha) - self.assertEqual([1, 1], dist.n.get_shape()) - self.assertAllClose(n, dist.n.eval()) + self.assertEqual([1, 1], dist.total_count.get_shape()) + self.assertAllClose(n, dist.total_count.eval()) def testAlphaProperty(self): alpha = [[1., 2, 3]] with self.test_session(): dist = ds.DirichletMultinomial(1, alpha) - self.assertEqual([1, 3], dist.alpha.get_shape()) - self.assertAllClose(alpha, dist.alpha.eval()) + self.assertEqual([1, 3], dist.concentration.get_shape()) + self.assertAllClose(alpha, dist.concentration.eval()) def testPmfNandCountsAgree(self): alpha = [[1., 2, 3]] @@ -72,9 +72,10 @@ class DirichletMultinomialTest(test.TestCase): dist = ds.DirichletMultinomial(n, 
alpha, validate_args=True) dist.prob([2., 3, 0]).eval() dist.prob([3., 0, 2]).eval() - with self.assertRaisesOpError("Condition x >= 0.*"): + with self.assertRaisesOpError("counts must be non-negative"): dist.prob([-1., 4, 2]).eval() - with self.assertRaisesOpError("counts do not sum to n"): + with self.assertRaisesOpError( + "counts last-dimension must sum to `self.total_count`"): dist.prob([3., 3, 0]).eval() def testPmfNonIntegerCounts(self): @@ -86,7 +87,8 @@ class DirichletMultinomialTest(test.TestCase): dist.prob([3., 0, 2]).eval() dist.prob([3.0, 0, 2.0]).eval() # Both equality and integer checking fail. - with self.assertRaisesOpError("Condition x == y.*"): + with self.assertRaisesOpError( + "counts cannot contain fractional components"): dist.prob([1.0, 2.5, 1.5]).eval() dist = ds.DirichletMultinomial(n, alpha, validate_args=False) dist.prob([1., 2., 3.]).eval() @@ -138,7 +140,7 @@ class DirichletMultinomialTest(test.TestCase): dist = ds.DirichletMultinomial([1.], alpha) pmf = dist.prob(counts) self.assertAllClose([1 / 3., 2 / 3.], pmf.eval()) - self.assertEqual((2), pmf.get_shape()) + self.assertAllEqual([2], pmf.get_shape()) def testPmfAlphaStretchedInBroadcastWhenLowerRank(self): # The probabilities of one vote falling into class k is the mean for class @@ -148,7 +150,7 @@ class DirichletMultinomialTest(test.TestCase): counts = [[1., 0], [0., 1]] pmf = ds.DirichletMultinomial(1., alpha).prob(counts) self.assertAllClose([1 / 3., 2 / 3.], pmf.eval()) - self.assertEqual((2), pmf.get_shape()) + self.assertAllEqual([2], pmf.get_shape()) def testPmfCountsStretchedInBroadcastWhenSameRank(self): # The probabilities of one vote falling into class k is the mean for class @@ -158,7 +160,7 @@ class DirichletMultinomialTest(test.TestCase): counts = [[1., 0]] pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts) self.assertAllClose([1 / 3., 2 / 5.], pmf.eval()) - self.assertEqual((2), pmf.get_shape()) + self.assertAllEqual([2], pmf.get_shape()) def testPmfCountsStretchedInBroadcastWhenLowerRank(self): # The probabilities of one vote falling into class k is the mean for class @@ -168,7 +170,7 @@ class DirichletMultinomialTest(test.TestCase): counts = [1., 0] pmf = ds.DirichletMultinomial(1., alpha).prob(counts) self.assertAllClose([1 / 3., 2 / 5.], pmf.eval()) - self.assertEqual((2), pmf.get_shape()) + self.assertAllEqual([2], pmf.get_shape()) def testPmfForOneVoteIsTheMeanWithOneRecordInput(self): # The probabilities of one vote falling into class k is the mean for class @@ -176,15 +178,15 @@ class DirichletMultinomialTest(test.TestCase): alpha = [1., 2, 3] with self.test_session(): for class_num in range(3): - counts = np.zeros((3), dtype=np.float32) + counts = np.zeros([3], dtype=np.float32) counts[class_num] = 1 dist = ds.DirichletMultinomial(1., alpha) mean = dist.mean().eval() pmf = dist.prob(counts).eval() self.assertAllClose(mean[class_num], pmf) - self.assertTupleEqual((3,), mean.shape) - self.assertTupleEqual((), pmf.shape) + self.assertAllEqual([3], mean.shape) + self.assertAllEqual([], pmf.shape) def testMeanDoubleTwoVotes(self): # The probabilities of two votes falling into class k for @@ -193,9 +195,9 @@ class DirichletMultinomialTest(test.TestCase): alpha = [1., 2, 3] with self.test_session(): for class_num in range(3): - counts_one = np.zeros((3), dtype=np.float32) + counts_one = np.zeros([3], dtype=np.float32) counts_one[class_num] = 1. 
- counts_two = np.zeros((3), dtype=np.float32) + counts_two = np.zeros([3], dtype=np.float32) counts_two[class_num] = 2 dist1 = ds.DirichletMultinomial(1., alpha) @@ -205,7 +207,7 @@ class DirichletMultinomialTest(test.TestCase): mean2 = dist2.mean().eval() self.assertAllClose(mean2[class_num], 2 * mean1[class_num]) - self.assertTupleEqual((3,), mean1.shape) + self.assertAllEqual([3], mean1.shape) def testCovarianceFromSampling(self): # We will test mean, cov, var, stddev on a DirichletMultinomial constructed @@ -279,7 +281,7 @@ class DirichletMultinomialTest(test.TestCase): covariance = dist.covariance() expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix - self.assertEqual((2, 2), covariance.get_shape()) + self.assertEqual([2, 2], covariance.get_shape()) self.assertAllClose(expected_covariance, covariance.eval()) def testCovarianceNAlphaBroadcast(self): @@ -415,7 +417,8 @@ class DirichletMultinomialTest(test.TestCase): def testSampleUnbiasedNonScalarBatch(self): with self.test_session() as sess: dist = ds.DirichletMultinomial( - n=5., alpha=2. * self._rng.rand(4, 3, 2).astype(np.float32)) + total_count=5., + concentration=2. * self._rng.rand(4, 3, 2).astype(np.float32)) n = int(3e3) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) @@ -443,7 +446,8 @@ class DirichletMultinomialTest(test.TestCase): def testSampleUnbiasedScalarBatch(self): with self.test_session() as sess: dist = ds.DirichletMultinomial( - n=5., alpha=2. * self._rng.rand(4).astype(np.float32)) + total_count=5., + concentration=2. * self._rng.rand(4).astype(np.float32)) n = int(5e3) x = dist.sample(n, seed=0) sample_mean = math_ops.reduce_mean(x, 0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py index 919905d245..cd634da09d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py @@ -46,12 +46,12 @@ class DirichletTest(test.TestCase): self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) - def testAlphaProperty(self): + def testConcentrationProperty(self): alpha = [[1., 2, 3]] with self.test_session(): dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual([1, 3], dist.alpha.get_shape()) - self.assertAllClose(alpha, dist.alpha.eval()) + self.assertEqual([1, 3], dist.concentration.get_shape()) + self.assertAllClose(alpha, dist.concentration.eval()) def testPdfXProper(self): alpha = [[1., 2, 3]] @@ -60,11 +60,12 @@ class DirichletTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("Condition x > 0.*|Condition x < y.*"): - dist.prob([-1., 1, 1]).eval() - with self.assertRaisesOpError("Condition x > 0.*"): + with self.assertRaisesOpError("samples must be positive"): + dist.prob([-1., 1.5, 0.5]).eval() + with self.assertRaisesOpError("samples must be positive"): dist.prob([0., .1, .9]).eval() - with self.assertRaisesOpError("Condition x ~= y.*"): + with self.assertRaisesOpError( + "sample last-dimension must sum to `1`"): dist.prob([.1, .2, .8]).eval() def testPdfZeroBatches(self): @@ -128,12 +129,12 @@ class DirichletTest(test.TestCase): self.assertAllClose([1., 3. 
/ 2], pdf.eval()) self.assertEqual((2), pdf.get_shape()) - def testDirichletMean(self): + def testMean(self): with self.test_session(): alpha = [1., 2, 3] expected_mean = stats.dirichlet.mean(alpha) - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha) - self.assertEqual(dirichlet.mean().get_shape(), (3,)) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mean().get_shape(), [3]) self.assertAllClose(dirichlet.mean().eval(), expected_mean) def testCovarianceFromSampling(self): @@ -172,51 +173,52 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) - def testDirichletCovariance(self): + def testVariance(self): with self.test_session(): alpha = [1., 2, 3] denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) expected_covariance = np.diag(stats.dirichlet.var(alpha)) expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]] / denominator - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) self.assertAllClose(dirichlet.covariance().eval(), expected_covariance) - def testDirichletMode(self): + def testMode(self): with self.test_session(): alpha = np.array([1.1, 2, 3]) expected_mode = (alpha - 1) / (np.sum(alpha) - 3) - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha) - self.assertEqual(dirichlet.mode().get_shape(), (3,)) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mode().get_shape(), [3]) self.assertAllClose(dirichlet.mode().eval(), expected_mode) - def testDirichletModeInvalid(self): + def testModeInvalid(self): with self.test_session(): alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha, allow_nan_stats=False) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, + allow_nan_stats=False) with self.assertRaisesOpError("Condition x < y.*"): dirichlet.mode().eval() - def testDirichletModeEnableAllowNanStats(self): + def testModeEnableAllowNanStats(self): with self.test_session(): alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha, allow_nan_stats=True) - expected_mode = (alpha - 1) / (np.sum(alpha) - 3) - expected_mode[0] = np.nan + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, + allow_nan_stats=True) + expected_mode = np.zeros_like(alpha) + np.nan - self.assertEqual(dirichlet.mode().get_shape(), (3,)) + self.assertEqual(dirichlet.mode().get_shape(), [3]) self.assertAllClose(dirichlet.mode().eval(), expected_mode) - def testDirichletEntropy(self): + def testEntropy(self): with self.test_session(): alpha = [1., 2, 3] expected_entropy = stats.dirichlet.entropy(alpha) - dirichlet = dirichlet_lib.Dirichlet(alpha=alpha) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) self.assertEqual(dirichlet.entropy().get_shape(), ()) self.assertAllClose(dirichlet.entropy().eval(), expected_entropy) - def testDirichletSample(self): + def testSample(self): with self.test_session(): alpha = [1., 2] dirichlet = dirichlet_lib.Dirichlet(alpha) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py index 5f9a74405a..6171202413 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py @@ -35,7 +35,7 @@ class 
ExponentialTest(test.TestCase): lam = constant_op.constant([2.0] * batch_size) lam_v = 2.0 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(lam=lam) + exponential = exponential_lib.Exponential(rate=lam) expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) log_pdf = exponential.log_prob(x) @@ -53,7 +53,7 @@ class ExponentialTest(test.TestCase): lam_v = 2.0 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(lam=lam) + exponential = exponential_lib.Exponential(rate=lam) expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) cdf = exponential.cdf(x) @@ -64,7 +64,7 @@ class ExponentialTest(test.TestCase): with session.Session(): lam_v = np.array([1.0, 4.0, 2.5]) expected_mean = stats.expon.mean(scale=1 / lam_v) - exponential = exponential_lib.Exponential(lam=lam_v) + exponential = exponential_lib.Exponential(rate=lam_v) self.assertEqual(exponential.mean().get_shape(), (3,)) self.assertAllClose(exponential.mean().eval(), expected_mean) @@ -72,7 +72,7 @@ class ExponentialTest(test.TestCase): with session.Session(): lam_v = np.array([1.0, 4.0, 2.5]) expected_variance = stats.expon.var(scale=1 / lam_v) - exponential = exponential_lib.Exponential(lam=lam_v) + exponential = exponential_lib.Exponential(rate=lam_v) self.assertEqual(exponential.variance().get_shape(), (3,)) self.assertAllClose(exponential.variance().eval(), expected_variance) @@ -80,7 +80,7 @@ class ExponentialTest(test.TestCase): with session.Session(): lam_v = np.array([1.0, 4.0, 2.5]) expected_entropy = stats.expon.entropy(scale=1 / lam_v) - exponential = exponential_lib.Exponential(lam=lam_v) + exponential = exponential_lib.Exponential(rate=lam_v) self.assertEqual(exponential.entropy().get_shape(), (3,)) self.assertAllClose(exponential.entropy().eval(), expected_entropy) @@ -89,7 +89,7 @@ class ExponentialTest(test.TestCase): lam = constant_op.constant([3.0, 4.0]) lam_v = [3.0, 4.0] n = constant_op.constant(100000) - exponential = exponential_lib.Exponential(lam=lam) + exponential = exponential_lib.Exponential(rate=lam) samples = exponential.sample(n, seed=137) sample_values = samples.eval() @@ -107,7 +107,7 @@ class ExponentialTest(test.TestCase): lam_v = [3.0, 22.0] lam = constant_op.constant([lam_v] * batch_size) - exponential = exponential_lib.Exponential(lam=lam) + exponential = exponential_lib.Exponential(rate=lam) n = 100000 samples = exponential.sample(n, seed=138) @@ -128,11 +128,12 @@ class ExponentialTest(test.TestCase): stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) - def testExponentialWithSoftplusLam(self): + def testExponentialWithSoftplusRate(self): with self.test_session(): lam = [-2.2, -3.4] - exponential = exponential_lib.ExponentialWithSoftplusLam(lam=lam) - self.assertAllClose(nn_ops.softplus(lam).eval(), exponential.lam.eval()) + exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) + self.assertAllClose(nn_ops.softplus(lam).eval(), + exponential.rate.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py index c5c071b49d..fd62710237 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py @@ -37,7 +37,7 @@ class GammaTest(test.TestCase): with self.test_session(): alpha = constant_op.constant([3.0] * 5) beta = constant_op.constant(11.0) - gamma = 
gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) self.assertEqual(gamma.batch_shape_tensor().eval(), (5,)) self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) @@ -52,7 +52,7 @@ class GammaTest(test.TestCase): alpha_v = 2.0 beta_v = 3.0 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) log_pdf = gamma.log_prob(x) self.assertEqual(log_pdf.get_shape(), (6,)) @@ -70,7 +70,7 @@ class GammaTest(test.TestCase): alpha_v = np.array([2.0, 4.0]) beta_v = np.array([3.0, 4.0]) x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) log_pdf = gamma.log_prob(x) log_pdf_values = log_pdf.eval() @@ -90,7 +90,7 @@ class GammaTest(test.TestCase): alpha_v = np.array([2.0, 4.0]) beta_v = 3.0 x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) log_pdf = gamma.log_prob(x) log_pdf_values = log_pdf.eval() @@ -111,7 +111,7 @@ class GammaTest(test.TestCase): beta_v = 3.0 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) cdf = gamma.cdf(x) @@ -122,7 +122,7 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) self.assertEqual(gamma.mean().get_shape(), (3,)) self.assertAllClose(gamma.mean().eval(), expected_means) @@ -131,7 +131,7 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([5.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_modes = (alpha_v - 1) / beta_v self.assertEqual(gamma.mode().get_shape(), (3,)) self.assertAllClose(gamma.mode().eval(), expected_modes) @@ -141,7 +141,9 @@ class GammaTest(test.TestCase): # Mode will not be defined for the first entry. alpha_v = np.array([0.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, allow_nan_stats=False) + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + allow_nan_stats=False) with self.assertRaisesOpError("x < y"): gamma.mode().eval() @@ -150,7 +152,9 @@ class GammaTest(test.TestCase): # Mode will not be defined for the first entry. 
alpha_v = np.array([0.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, allow_nan_stats=True) + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + allow_nan_stats=True) expected_modes = (alpha_v - 1) / beta_v expected_modes[0] = np.nan self.assertEqual(gamma.mode().get_shape(), (3,)) @@ -160,7 +164,7 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) self.assertEqual(gamma.variance().get_shape(), (3,)) self.assertAllClose(gamma.variance().eval(), expected_variances) @@ -169,7 +173,7 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) self.assertEqual(gamma.stddev().get_shape(), (3,)) self.assertAllClose(gamma.stddev().eval(), expected_stddev) @@ -179,7 +183,7 @@ class GammaTest(test.TestCase): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) self.assertEqual(gamma.entropy().get_shape(), (3,)) self.assertAllClose(gamma.entropy().eval(), expected_entropy) @@ -190,7 +194,7 @@ class GammaTest(test.TestCase): alpha = constant_op.constant(alpha_v) beta = constant_op.constant(beta_v) n = 100000 - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (n,)) @@ -213,7 +217,7 @@ class GammaTest(test.TestCase): alpha = constant_op.constant(alpha_v) beta = constant_op.constant(beta_v) n = 100000 - gamma = gamma_lib.Gamma(alpha=alpha, beta=beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (n,)) @@ -233,7 +237,7 @@ class GammaTest(test.TestCase): with session.Session(): alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) n = 10000 samples = gamma.sample(n, seed=137) sample_values = samples.eval() @@ -268,7 +272,7 @@ class GammaTest(test.TestCase): def testGammaPdfOfSampleMultiDims(self): with session.Session() as sess: - gamma = gamma_lib.Gamma(alpha=[7., 11.], beta=[[5.], [6.]]) + gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) num = 50000 samples = gamma.sample(num, seed=137) pdfs = gamma.prob(samples) @@ -304,22 +308,29 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, validate_args=True) + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) with self.assertRaisesOpError("alpha"): gamma.mean().eval() alpha_v = constant_op.constant(1.0, name="alpha") 
beta_v = constant_op.constant(0.0, name="beta") - gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, validate_args=True) + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) with self.assertRaisesOpError("beta"): gamma.mean().eval() - def testGammaWithSoftplusAlphaBeta(self): + def testGammaWithSoftplusConcentrationRate(self): with self.test_session(): alpha_v = constant_op.constant([0.0, -2.1], name="alpha") beta_v = constant_op.constant([1.0, -3.6], name="beta") - gamma = gamma_lib.GammaWithSoftplusAlphaBeta(alpha=alpha_v, beta=beta_v) - self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), gamma.alpha.eval()) - self.assertAllEqual(nn_ops.softplus(beta_v).eval(), gamma.beta.eval()) + gamma = gamma_lib.GammaWithSoftplusConcentrationRate( + concentration=alpha_v, rate=beta_v) + self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), + gamma.concentration.eval()) + self.assertAllEqual(nn_ops.softplus(beta_v).eval(), + gamma.rate.eval()) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -330,8 +341,8 @@ class GammaTest(test.TestCase): # Build graph. with self.test_session() as sess: - g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0) - g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1) + g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) + g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) x = g0.sample(int(1e4), seed=0) kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) kl_actual = kullback_leibler.kl(g0, g1) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py index 9bcca23966..6eb96ea9ff 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py @@ -33,7 +33,7 @@ class InverseGammaTest(test.TestCase): with self.test_session(): alpha = constant_op.constant([3.0] * 5) beta = constant_op.constant(11.0) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) self.assertEqual(inv_gamma.batch_shape_tensor().eval(), (5,)) self.assertEqual(inv_gamma.batch_shape, @@ -50,7 +50,7 @@ class InverseGammaTest(test.TestCase): alpha_v = 2.0 beta_v = 3.0 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v) log_pdf = inv_gamma.log_prob(x) self.assertEqual(log_pdf.get_shape(), (6,)) @@ -68,7 +68,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([2.0, 4.0]) beta_v = np.array([3.0, 4.0]) x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v) log_pdf = inv_gamma.log_prob(x) log_pdf_values = log_pdf.eval() @@ -88,7 +88,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([2.0, 4.0]) beta_v = 3.0 x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v) log_pdf = inv_gamma.log_prob(x) log_pdf_values = 
log_pdf.eval() @@ -109,7 +109,7 @@ class InverseGammaTest(test.TestCase): beta = constant_op.constant([beta_v] * batch_size) x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) expected_cdf = stats.invgamma.cdf(x, alpha_v, scale=beta_v) cdf = inv_gamma.cdf(x) @@ -120,7 +120,7 @@ class InverseGammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([5.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v) expected_modes = beta_v / (alpha_v + 1) self.assertEqual(inv_gamma.mode().get_shape(), (3,)) self.assertAllClose(inv_gamma.mode().eval(), expected_modes) @@ -129,7 +129,7 @@ class InverseGammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([5.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v) expected_means = stats.invgamma.mean(alpha_v, scale=beta_v) self.assertEqual(inv_gamma.mean().get_shape(), (3,)) self.assertAllClose(inv_gamma.mean().eval(), expected_means) @@ -140,7 +140,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, allow_nan_stats=False) + concentration=alpha_v, rate=beta_v, allow_nan_stats=False) with self.assertRaisesOpError("x < y"): inv_gamma.mean().eval() @@ -150,7 +150,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([0.5, 1.0, 3.0, 2.5]) beta_v = np.array([1.0, 2.0, 4.0, 5.0]) inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, allow_nan_stats=True) + concentration=alpha_v, rate=beta_v, allow_nan_stats=True) expected_means = beta_v / (alpha_v - 1) expected_means[0] = np.nan expected_means[1] = np.nan @@ -161,7 +161,7 @@ class InverseGammaTest(test.TestCase): with self.test_session(): alpha_v = np.array([7.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v) expected_variances = stats.invgamma.var(alpha_v, scale=beta_v) self.assertEqual(inv_gamma.variance().get_shape(), (3,)) self.assertAllClose(inv_gamma.variance().eval(), expected_variances) @@ -171,7 +171,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([1.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, allow_nan_stats=False) + concentration=alpha_v, rate=beta_v, allow_nan_stats=False) with self.assertRaisesOpError("x < y"): inv_gamma.variance().eval() @@ -180,7 +180,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([1.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, allow_nan_stats=True) + concentration=alpha_v, rate=beta_v, allow_nan_stats=True) expected_variances = stats.invgamma.var(alpha_v, scale=beta_v) expected_variances[0] = np.nan self.assertEqual(inv_gamma.variance().get_shape(), (3,)) @@ -191,7 +191,7 @@ class InverseGammaTest(test.TestCase): alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) expected_entropy = stats.invgamma.entropy(alpha_v, scale=beta_v) - 
inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v) self.assertEqual(inv_gamma.entropy().get_shape(), (3,)) self.assertAllClose(inv_gamma.entropy().eval(), expected_entropy) @@ -202,7 +202,7 @@ class InverseGammaTest(test.TestCase): alpha = constant_op.constant(alpha_v) beta = constant_op.constant(beta_v) n = 100000 - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta) samples = inv_gamma.sample(n, seed=137) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (n,)) @@ -222,7 +222,7 @@ class InverseGammaTest(test.TestCase): with session.Session(): alpha_v = np.array([np.arange(3, 103, dtype=np.float32)]) # 1 x 100 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v) + inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v) n = 10000 samples = inv_gamma.sample(n, seed=137) sample_values = samples.eval() @@ -257,7 +257,9 @@ class InverseGammaTest(test.TestCase): def testInverseGammaPdfOfSampleMultiDims(self): with session.Session() as sess: - inv_gamma = inverse_gamma.InverseGamma(alpha=[7., 11.], beta=[[5.], [6.]]) + inv_gamma = inverse_gamma.InverseGamma( + concentration=[7., 11.], + rate=[[5.], [6.]]) num = 50000 samples = inv_gamma.sample(num, seed=137) pdfs = inv_gamma.prob(samples) @@ -294,24 +296,26 @@ class InverseGammaTest(test.TestCase): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, validate_args=True) + concentration=alpha_v, rate=beta_v, validate_args=True) with self.assertRaisesOpError("alpha"): inv_gamma.mean().eval() alpha_v = constant_op.constant(1.0, name="alpha") beta_v = constant_op.constant(0.0, name="beta") inv_gamma = inverse_gamma.InverseGamma( - alpha=alpha_v, beta=beta_v, validate_args=True) + concentration=alpha_v, rate=beta_v, validate_args=True) with self.assertRaisesOpError("beta"): inv_gamma.mean().eval() - def testInverseGammaWithSoftplusAlphaBeta(self): + def testInverseGammaWithSoftplusConcentrationRate(self): with self.test_session(): alpha = constant_op.constant([-0.1, -2.9], name="alpha") beta = constant_op.constant([1.0, -4.8], name="beta") - inv_gamma = inverse_gamma.InverseGammaWithSoftplusAlphaBeta( - alpha=alpha, beta=beta, validate_args=True) - self.assertAllClose(nn_ops.softplus(alpha).eval(), inv_gamma.alpha.eval()) - self.assertAllClose(nn_ops.softplus(beta).eval(), inv_gamma.beta.eval()) + inv_gamma = inverse_gamma.InverseGammaWithSoftplusConcentrationRate( + concentration=alpha, rate=beta, validate_args=True) + self.assertAllClose(nn_ops.softplus(alpha).eval(), + inv_gamma.concentration.eval()) + self.assertAllClose(nn_ops.softplus(beta).eval(), + inv_gamma.rate.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index c88af121a6..0adaf7d816 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -30,7 +30,7 @@ class PoissonTest(test.TestCase): def testPoissonShape(self): with self.test_session(): lam = constant_op.constant([3.0] * 5) - poisson = poisson_lib.Poisson(lam=lam) + poisson = 
poisson_lib.Poisson(rate=lam) self.assertEqual(poisson.batch_shape_tensor().eval(), (5,)) self.assertEqual(poisson.batch_shape, tensor_shape.TensorShape([5])) @@ -38,16 +38,12 @@ class PoissonTest(test.TestCase): self.assertEqual(poisson.event_shape, tensor_shape.TensorShape([])) def testInvalidLam(self): - invalid_lams = [ - -.01, - 0, - -2., - ] + invalid_lams = [-.01, 0, -2.] for lam in invalid_lams: with self.test_session(): with self.assertRaisesOpError("Condition x > 0"): - poisson = poisson_lib.Poisson(lam=lam, validate_args=True) - poisson.lam.eval() + poisson = poisson_lib.Poisson(rate=lam, validate_args=True) + poisson.rate.eval() def testPoissonLogPmf(self): with self.test_session(): @@ -55,7 +51,7 @@ class PoissonTest(test.TestCase): lam = constant_op.constant([3.0] * batch_size) lam_v = 3.0 x = [2., 3., 4., 5., 6., 7.] - poisson = poisson_lib.Poisson(lam=lam) + poisson = poisson_lib.Poisson(rate=lam) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6,)) self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v)) @@ -69,7 +65,7 @@ class PoissonTest(test.TestCase): batch_size = 6 lam = constant_op.constant([3.0] * batch_size) x = [2.5, 3.2, 4.3, 5.1, 6., 7.] - poisson = poisson_lib.Poisson(lam=lam, validate_args=True) + poisson = poisson_lib.Poisson(rate=lam, validate_args=True) # Non-integer with self.assertRaisesOpError("x has non-integer components"): @@ -80,7 +76,7 @@ class PoissonTest(test.TestCase): log_pmf = poisson.log_prob([-1.]) log_pmf.eval() - poisson = poisson_lib.Poisson(lam=lam, validate_args=False) + poisson = poisson_lib.Poisson(rate=lam, validate_args=False) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6,)) pmf = poisson.prob(x) @@ -93,7 +89,7 @@ class PoissonTest(test.TestCase): lam_v = [2.0, 4.0, 5.0] x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T - poisson = poisson_lib.Poisson(lam=lam) + poisson = poisson_lib.Poisson(rate=lam) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6, 3)) self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v)) @@ -109,7 +105,7 @@ class PoissonTest(test.TestCase): lam_v = 3.0 x = [2.2, 3.1, 4., 5.5, 6., 7.] 
- poisson = poisson_lib.Poisson(lam=lam) + poisson = poisson_lib.Poisson(rate=lam) log_cdf = poisson.log_cdf(x) self.assertEqual(log_cdf.get_shape(), (6,)) self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v)) @@ -125,7 +121,7 @@ class PoissonTest(test.TestCase): lam_v = [2.0, 4.0, 5.0] x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T - poisson = poisson_lib.Poisson(lam=lam) + poisson = poisson_lib.Poisson(rate=lam) log_cdf = poisson.log_cdf(x) self.assertEqual(log_cdf.get_shape(), (6, 3)) self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v)) @@ -137,7 +133,7 @@ class PoissonTest(test.TestCase): def testPoissonMean(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(lam=lam_v) + poisson = poisson_lib.Poisson(rate=lam_v) self.assertEqual(poisson.mean().get_shape(), (3,)) self.assertAllClose(poisson.mean().eval(), stats.poisson.mean(lam_v)) self.assertAllClose(poisson.mean().eval(), lam_v) @@ -145,7 +141,7 @@ class PoissonTest(test.TestCase): def testPoissonVariance(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(lam=lam_v) + poisson = poisson_lib.Poisson(rate=lam_v) self.assertEqual(poisson.variance().get_shape(), (3,)) self.assertAllClose(poisson.variance().eval(), stats.poisson.var(lam_v)) self.assertAllClose(poisson.variance().eval(), lam_v) @@ -153,7 +149,7 @@ class PoissonTest(test.TestCase): def testPoissonStd(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(lam=lam_v) + poisson = poisson_lib.Poisson(rate=lam_v) self.assertEqual(poisson.stddev().get_shape(), (3,)) self.assertAllClose(poisson.stddev().eval(), stats.poisson.std(lam_v)) self.assertAllClose(poisson.stddev().eval(), np.sqrt(lam_v)) @@ -161,14 +157,14 @@ class PoissonTest(test.TestCase): def testPoissonMode(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5, 3.2, 1.1, 0.05] - poisson = poisson_lib.Poisson(lam=lam_v) + poisson = poisson_lib.Poisson(rate=lam_v) self.assertEqual(poisson.mode().get_shape(), (6,)) self.assertAllClose(poisson.mode().eval(), np.floor(lam_v)) def testPoissonMultipleMode(self): with self.test_session(): lam_v = [1.0, 3.0, 2.0, 4.0, 5.0, 10.0] - poisson = poisson_lib.Poisson(lam=lam_v) + poisson = poisson_lib.Poisson(rate=lam_v) # For the case where lam is an integer, the modes are: lam and lam - 1. # In this case, we get back the larger of the two modes. self.assertEqual((6,), poisson.mode().get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py index de7141e660..0e2d143732 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py @@ -201,7 +201,7 @@ class QuantizedDistributionTest(test.TestCase): # integers. Hence, F(X) (see below) will not be uniform exactly. with self.test_session(): qdist = distributions.QuantizedDistribution( - distribution=distributions.Exponential(lam=0.01)) + distribution=distributions.Exponential(rate=0.01)) # X ~ QuantizedExponential x = qdist.sample(10000, seed=42) # Z = F(X), should be Uniform. @@ -224,7 +224,7 @@ class QuantizedDistributionTest(test.TestCase): # Make an exponential with mean 5. 
with self.test_session(): qdist = distributions.QuantizedDistribution( - distribution=distributions.Exponential(lam=0.2)) + distribution=distributions.Exponential(rate=0.2)) # Standard error should be less than 1 / (2 * sqrt(n_samples)) n_samples = 10000 stddev_err_bound = 1 / (2 * np.sqrt(n_samples)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 754a896f75..1fa6ca0906 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -67,10 +67,13 @@ class WishartCholeskyTest(test.TestCase): with self.test_session(): def entropy_alt(w): - return (w.log_normalizing_constant() - 0.5 * (w.df - w.dimension - 1.) * - w.mean_log_det() + 0.5 * w.df * w.dimension).eval() + return ( + w.log_normalization() + - 0.5 * (w.df - w.dimension - 1.) * w.mean_log_det() + + 0.5 * w.df * w.dimension).eval() - w = distributions.WishartCholesky(df=4, scale=chol(make_pd(1., 2))) + w = distributions.WishartCholesky(df=4, + scale=chol(make_pd(1., 2))) self.assertAllClose(w.entropy().eval(), entropy_alt(w)) w = distributions.WishartCholesky(df=5, scale=[[1.]]) @@ -204,7 +207,9 @@ class WishartCholeskyTest(test.TestCase): # This test checks that batches don't interfere with correctness. w = distributions.WishartCholesky( - df=[2, 3, 4, 5], scale=chol_x, cholesky_input_output_matrices=True) + df=[2, 3, 4, 5], + scale=chol_x, + cholesky_input_output_matrices=True) self.assertAllClose(log_prob_df_seq, w.log_prob(chol_x).eval()) # Now we test various constructions of Wishart with different sample @@ -221,10 +226,15 @@ class WishartCholeskyTest(test.TestCase): -20.951582705289454, ]) - for w in (distributions.WishartCholesky( - df=4, scale=chol_x[0], cholesky_input_output_matrices=False), - distributions.WishartFull( - df=4, scale=x[0], cholesky_input_output_matrices=False)): + for w in ( + distributions.WishartCholesky( + df=4, + scale=chol_x[0], + cholesky_input_output_matrices=False), + distributions.WishartFull( + df=4, + scale=x[0], + cholesky_input_output_matrices=False)): self.assertAllEqual((2, 2), w.event_shape_tensor().eval()) self.assertEqual(2, w.dimension.eval()) self.assertAllClose(log_prob[0], w.log_prob(x[0]).eval()) @@ -238,10 +248,15 @@ class WishartCholeskyTest(test.TestCase): self.assertAllEqual((2, 2), w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape()) - for w in (distributions.WishartCholesky( - df=4, scale=chol_x[0], cholesky_input_output_matrices=True), - distributions.WishartFull( - df=4, scale=x[0], cholesky_input_output_matrices=True)): + for w in ( + distributions.WishartCholesky( + df=4, + scale=chol_x[0], + cholesky_input_output_matrices=True), + distributions.WishartFull( + df=4, + scale=x[0], + cholesky_input_output_matrices=True)): self.assertAllEqual((2, 2), w.event_shape_tensor().eval()) self.assertEqual(2, w.dimension.eval()) self.assertAllClose(log_prob[0], w.log_prob(chol_x[0]).eval()) @@ -315,7 +330,9 @@ class WishartCholeskyTest(test.TestCase): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "cannot be less than"): chol_w = distributions.WishartCholesky( - df=df_deferred, scale=chol_scale_deferred, validate_args=True) + df=df_deferred, + scale=chol_scale_deferred, + validate_args=True) sess.run(chol_w.log_prob(np.asarray( x, dtype=np.float32)), feed_dict={df_deferred: 2., @@ -336,7 +353,9 @@ class WishartCholeskyTest(test.TestCase): # Ensure no assertions. 
chol_w = distributions.WishartCholesky( - df=df_deferred, scale=chol_scale_deferred, validate_args=False) + df=df_deferred, + scale=chol_scale_deferred, + validate_args=False) sess.run(chol_w.log_prob(np.asarray( x, dtype=np.float32)), feed_dict={df_deferred: 4, diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index baab0503f5..b4e9bc5b2e 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -36,158 +36,170 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops -_beta_prob_note = """ - Note that the argument `x` must be a non-negative floating point tensor - whose shape can be broadcast with `self.a` and `self.b`. For fixed leading - dimensions, the last dimension represents counts for the corresponding Beta - distribution in `self.a` and `self.b`. `x` is only legal if `0 < x < 1`. -""" +__all__ = [ + "Beta", + "BetaWithSoftplusConcentration", +] -class Beta(distribution.Distribution): - """Beta distribution. +_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in +`[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" - This distribution is parameterized by `a` and `b` which are shape - parameters. - #### Mathematical details +class Beta(distribution.Distribution): + """Beta distribution. - The Beta is a distribution over the interval (0, 1). - The distribution has hyperparameters `a` and `b` and - probability mass function (pdf): + The Beta distribution is defined over the `(0, 1)` interval using parameters + `concentration1` (aka "alpha") and `concentration0` (aka "beta"). - ```pdf(x) = 1 / Beta(a, b) * x^(a - 1) * (1 - x)^(b - 1)``` + #### Mathematical Details - where `Beta(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)` - is the beta function. + The probability density function (pdf) is, + ```none + pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z + Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta) + ``` - This class provides methods to create indexed batches of Beta - distributions. One entry of the broadcasted - shape represents of `a` and `b` represents one single Beta distribution. - When calling distribution functions (e.g. `dist.prob(x)`), `a`, `b` - and `x` are broadcast to the same shape (if possible). - Every entry in a/b/x corresponds to a single Beta distribution. + where: - #### Examples + * `concentration1 = alpha`, + * `concentration0 = beta`, + * `Z` is the normalization constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function). - Creates 3 distributions. - The distribution functions can be evaluated on x. + The concentration parameters represent mean total counts of a `1` or a `0`, + i.e., - ```python - a = [1, 2, 3] - b = [1, 2, 3] - dist = Beta(a, b) + ```none + concentration1 = alpha = mean * total_concentration + concentration0 = beta = (1. - mean) * total_concentration ``` - ```python - # x same shape as a. - x = [.2, .3, .7] - dist.prob(x) # Shape [3] - - # a/b will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x. - x = [[.1, .4, .5], [.2, .3, .5]] - dist.prob(x) # Shape [2, 3] + where `mean` in `(0, 1)` and `total_concentration` is a positive real number + representing a mean `total_count = concentration1 + concentration0`. - # a/b will be broadcast to shape [5, 7, 3] to match x. 
- x = [[...]] # Shape [5, 7, 3] - dist.prob(x) # Shape [5, 7, 3] - ``` + Distribution parameters are automatically broadcast in all functions; see + examples for details. - Creates a 2-batch of 3-class distributions. + #### Examples ```python - a = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3] - b = 5 # Shape [] - dist = Beta(a, b) + # Create a batch of three Beta distributions. + alpha = [1, 2, 3] + beta = [1, 2, 3] + dist = Beta(alpha, beta) + + dist.sample([4, 5]) # Shape [4, 5, 3] + + # `x` has three batch entries, each with two samples. + x = [[.1, .4, .5], + [.2, .3, .5]] + # Calculate the probability of each pair of samples under the corresponding + # distribution in `dist`. + dist.prob(x) # Shape [2, 3] + ``` - # x will be broadcast to [[.2, .3, .9], [.2, .3, .9]] to match a/b. - x = [.2, .3, .9] - dist.prob(x) # Shape [2] + ```python + # Create batch_shape=[2, 3] via parameter broadcast: + alpha = [[1.], [2]] # Shape [2, 1] + beta = [3., 4, 5] # Shape [3] + dist = Beta(alpha, beta) + + # alpha broadcast as: [[1., 1, 1,], + # [2, 2, 2]] + # beta broadcast as: [[3., 4, 5], + # [3, 4, 5]] + # batch_Shape [2, 3] + dist.sample([4, 5]) # Shape [4, 5, 2, 3] + + x = [.2, .3, .5] + # x will be broadcast as [[.2, .3, .5], + # [.2, .3, .5]], + # thus matching batch_shape [2, 3]. + dist.prob(x) # Shape [2, 3] ``` """ - def __init__(self, a, b, validate_args=False, allow_nan_stats=True, + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + allow_nan_stats=True, name="Beta"): """Initialize a batch of Beta distributions. Args: - a: Positive floating point tensor with shape broadcastable to - `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` - different Beta distributions. This also defines the - dtype of the distribution. - b: Positive floating point tensor with shape broadcastable to - `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` - different Beta distributions. - validate_args: `Boolean`, default `False`. Whether to assert valid - values for parameters `a`, `b`, and `x` in `prob` and `log_prob`. - If `False` and inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to prefix Ops created by this distribution class. - - Examples: - - ```python - # Define 1-batch. - dist = Beta(1.1, 2.0) - - # Define a 2-batch. - dist = Beta([1.0, 2.0], [4.0, 5.0]) - ``` - + concentration1: Positive floating-point `Tensor` indicating mean + number of successes; aka "alpha". Implies `self.dtype` and + `self.batch_shape`, i.e., + `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. + concentration0: Positive floating-point `Tensor` indicating mean + number of failures; aka "beta". Otherwise has same semantics as + `concentration1`. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. 
+ name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[a, b]) as ns: - with ops.control_dependencies([ - check_ops.assert_positive(a), - check_ops.assert_positive(b), - ] if validate_args else []): - self._a = array_ops.identity(a, name="a") - self._b = array_ops.identity(b, name="b") - contrib_tensor_util.assert_same_float_dtype((self._a, self._b)) - # Used for mean/mode/variance/entropy/sampling computations - self._a_b_sum = self._a + self._b + with ops.name_scope(name, values=[concentration1, + concentration0]) as ns: + self._concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args) + self._concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args) + contrib_tensor_util.assert_same_float_dtype([ + self._concentration1, self._concentration0]) + self._total_concentration = self._concentration1 + self._concentration0 super(Beta, self).__init__( - dtype=self._a_b_sum.dtype, + dtype=self._total_concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, is_continuous=True, reparameterization_type=distribution.NOT_REPARAMETERIZED, parameters=parameters, - graph_parents=[self._a, self._b, self._a_b_sum], + graph_parents=[self._concentration1, + self._concentration0, + self._total_concentration], name=ns) @staticmethod def _param_shapes(sample_shape): - return dict( - zip(("a", "b"), ([ops.convert_to_tensor( - sample_shape, dtype=dtypes.int32)] * 2))) + return dict(zip( + ["concentration1", "concentration0"], + [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)) @property - def a(self): - """Shape parameter.""" - return self._a + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self._concentration1 @property - def b(self): - """Shape parameter.""" - return self._b + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self._concentration0 @property - def a_b_sum(self): - """Sum of parameters.""" - return self._a_b_sum + def total_concentration(self): + """Sum of concentration parameters.""" + return self._total_concentration def _batch_shape_tensor(self): - return array_ops.shape(self.a_b_sum) + return array_ops.shape(self.total_concentration) def _batch_shape(self): - return self.a_b_sum.get_shape() + return self.total_concentration.get_shape() def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) @@ -196,103 +208,130 @@ class Beta(distribution.Distribution): return tensor_shape.scalar() def _sample_n(self, n, seed=None): - a = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.a - b = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.b + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 gamma1_sample = random_ops.random_gamma( - [n,], a, dtype=self.dtype, seed=seed) + shape=[n], + alpha=expanded_concentration1, + dtype=self.dtype, + seed=seed) gamma2_sample = random_ops.random_gamma( - [n,], b, dtype=self.dtype, + shape=[n], + alpha=expanded_concentration0, + dtype=self.dtype, seed=distribution_util.gen_new_seed(seed, "beta")) beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample) return 
beta_sample + @distribution_util.AppendDocstring(_beta_sample_note) def _log_prob(self, x): - x = self._assert_valid_sample(x) - log_unnormalized_prob = ((self.a - 1.) * math_ops.log(x) + - (self.b - 1.) * math_ops.log(1. - x)) - log_normalization = (math_ops.lgamma(self.a) + - math_ops.lgamma(self.b) - - math_ops.lgamma(self.a_b_sum)) - return log_unnormalized_prob - log_normalization - - @distribution_util.AppendDocstring(_beta_prob_note) + return self._log_unnormalized_prob(x) - self._log_normalization() + + @distribution_util.AppendDocstring(_beta_sample_note) def _prob(self, x): return math_ops.exp(self._log_prob(x)) - @distribution_util.AppendDocstring(_beta_prob_note) + @distribution_util.AppendDocstring(_beta_sample_note) def _log_cdf(self, x): return math_ops.log(self._cdf(x)) + @distribution_util.AppendDocstring(_beta_sample_note) def _cdf(self, x): - return math_ops.betainc(self.a, self.b, x) + return math_ops.betainc(self.concentration1, self.concentration0, x) + + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + return ((self.concentration1 - 1.) * math_ops.log(x) + + (self.concentration0 - 1.) * math_ops.log1p(-x)) + + def _log_normalization(self): + return (math_ops.lgamma(self.concentration1) + + math_ops.lgamma(self.concentration0) + - math_ops.lgamma(self.total_concentration)) def _entropy(self): - return (math_ops.lgamma(self.a) - - (self.a - 1.) * math_ops.digamma(self.a) + - math_ops.lgamma(self.b) - - (self.b - 1.) * math_ops.digamma(self.b) - - math_ops.lgamma(self.a_b_sum) + - (self.a_b_sum - 2.) * math_ops.digamma(self.a_b_sum)) + return ( + self._log_normalization() + - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1) + - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0) + + ((self.total_concentration - 2.) * + math_ops.digamma(self.total_concentration))) def _mean(self): - return self.a / self.a_b_sum + return self._concentration1 / self._total_concentration def _variance(self): - return (self.a * self.b) / (self.a_b_sum**2. * (self.a_b_sum + 1.)) + return self._mean() * (1. - self._mean()) / (1. + self.total_concentration) @distribution_util.AppendDocstring( - """Note that the mode for the Beta distribution is only defined - when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`, - and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception - will be raised rather than returning `NaN`.""") + """Note: The mode is undefined when `concentration1 <= 1` or + `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` + is used for undefined modes. If `self.allow_nan_stats` is `False` an + exception is raised when one or more modes are undefined.""") def _mode(self): - mode = (self.a - 1.)/ (self.a_b_sum - 2.) + mode = (self.concentration1 - 1.) / (self.total_concentration - 2.) 
if self.allow_nan_stats: - nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return array_ops.where( - math_ops.logical_and( - math_ops.greater(self.a, 1.), - math_ops.greater(self.b, 1.)), - mode, - array_ops.fill(self.batch_shape_tensor(), nan, name="nan")) - else: - return control_flow_ops.with_dependencies([ - check_ops.assert_less( - array_ops.ones((), dtype=self.dtype), self.a, - message="Mode not defined for components of a <= 1."), - check_ops.assert_less( - array_ops.ones((), dtype=self.dtype), self.b, - message="Mode not defined for components of b <= 1."), - ], mode) - - def _assert_valid_sample(self, x): - """Check x for proper shape, values, then return tensor version.""" - if not self.validate_args: return x + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") + is_defined = math_ops.logical_and(self.concentration1 > 1., + self.concentration0 > 1.) + return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration1, + message="Mode undefined for concentration1 <= 1."), + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration0, + message="Mode undefined for concentration0 <= 1.") + ], mode) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid_sample(self, x): + """Checks the validity of a sample.""" + if not self.validate_args: + return x return control_flow_ops.with_dependencies([ check_ops.assert_positive( x, - message="Negative events lie outside Beta distribution support."), + message="sample must be positive"), check_ops.assert_less( - x, array_ops.ones((), self.dtype), - message="Event>=1 lies outside Beta distribution support."), + x, array_ops.ones([], self.dtype), + message="sample must be no larger than `1`."), ], x) -class BetaWithSoftplusAB(Beta): - """Beta with softplus transform on `a` and `b`.""" +class BetaWithSoftplusConcentration(Beta): + """Beta with softplus transform of `concentration1` and `concentration0`.""" def __init__(self, - a, - b, + concentration1, + concentration0, validate_args=False, allow_nan_stats=True, - name="BetaWithSoftplusAB"): + name="BetaWithSoftplusConcentration"): parameters = locals() - parameters.pop("self") - with ops.name_scope(name, values=[a, b]) as ns: - super(BetaWithSoftplusAB, self).__init__( - a=nn.softplus(a, name="softplus_a"), - b=nn.softplus(b, name="softplus_b"), + with ops.name_scope(name, values=[concentration1, + concentration0]) as ns: + super(BetaWithSoftplusConcentration, self).__init__( + concentration1=nn.softplus(concentration1, + name="softplus_concentration1"), + concentration0=nn.softplus(concentration0, + name="softplus_concentration0"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) @@ -301,7 +340,7 @@ class BetaWithSoftplusAB(Beta): @kullback_leibler.RegisterKL(Beta, Beta) def _kl_beta_beta(d1, d2, name=None): - """Calculate the batched KL divergence KL(d1 || d2) with d1 and d2 Beta. + """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta. Args: d1: instance of a Beta distribution object. 
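Not part of the patch: the hunk below rewrites `_kl_beta_beta` so that a small `delta` helper differences each quantity between `d2` and `d1` instead of spelling out the lgamma/digamma arithmetic. As a sanity check, here is a minimal NumPy/SciPy sketch of the same closed form, with the correspondence to each `delta(...)` call marked in comments; the function name and test values are illustrative only.

```python
# Sketch (not from the patch) of the Beta-Beta KL closed form that the
# rewritten `_kl_beta_beta` below computes via its `delta` helper.
import numpy as np
from scipy.special import betaln, digamma

def kl_beta_beta(c1_a, c0_a, c1_b, c0_b):
  """KL( Beta(c1_a, c0_a) || Beta(c1_b, c0_b) ), batchwise."""
  total_a = c1_a + c0_a
  total_b = c1_b + c0_b
  # betaln(c1, c0) = lgamma(c1) + lgamma(c0) - lgamma(c1 + c0), i.e. the
  # `_log_normalization` differenced by delta("_log_normalization").
  return (betaln(c1_b, c0_b) - betaln(c1_a, c0_a)
          - digamma(c1_a) * (c1_b - c1_a)            # delta("concentration1")
          - digamma(c0_a) * (c0_b - c0_a)            # delta("concentration0")
          + digamma(total_a) * (total_b - total_a))  # delta("total_concentration")

# Monte Carlo sanity check: E_p[log p(x) - log q(x)] should agree (~0.183).
rng = np.random.default_rng(0)
a1, a0, b1, b0 = 2., 3., 4., 5.
x = rng.beta(a1, a0, size=200000)
log_ratio = ((a1 - b1) * np.log(x) + (a0 - b0) * np.log1p(-x)
             + betaln(b1, b0) - betaln(a1, a0))
print(kl_beta_beta(a1, a0, b1, b0), log_ratio.mean())
```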
@@ -312,14 +351,20 @@ def _kl_beta_beta(d1, d2, name=None): Returns: Batchwise KL(d1 || d2) """ - inputs = [d1.a, d1.b, d1.a_b_sum, d2.a_b_sum] - with ops.name_scope(name, "kl_beta_beta", inputs): - # ln(B(a', b') / B(a, b)) - log_betas = (math_ops.lgamma(d2.a) + math_ops.lgamma(d2.b) - - math_ops.lgamma(d2.a_b_sum) + math_ops.lgamma(d1.a_b_sum) - - math_ops.lgamma(d1.a) - math_ops.lgamma(d1.b)) - # (a - a')*psi(a) + (b - b')*psi(b) + (a' - a + b' - b)*psi(a + b) - digammas = ((d1.a - d2.a) * math_ops.digamma(d1.a) - + (d1.b - d2.b) * math_ops.digamma(d1.b) - + (d2.a_b_sum - d1.a_b_sum) * math_ops.digamma(d1.a_b_sum)) - return log_betas + digammas + def delta(fn, is_property=True): + fn1 = getattr(d1, fn) + fn2 = getattr(d2, fn) + return (fn2 - fn1) if is_property else (fn2() - fn1()) + with ops.name_scope(name, "kl_beta_beta", values=[ + d1.concentration1, + d1.concentration0, + d1.total_concentration, + d2.concentration1, + d2.concentration0, + d2.total_concentration, + ]): + return (delta("_log_normalization", is_property=False) + - math_ops.digamma(d1.concentration1) * delta("concentration1") + - math_ops.digamma(d1.concentration0) * delta("concentration0") + + (math_ops.digamma(d1.total_concentration) + * delta("total_concentration"))) diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 210cbe95c8..870646e26f 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,15 +25,40 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops +__all__ = [ + "Chi2", + "Chi2WithAbsDf", +] + + class Chi2(gamma.Gamma): - """The Chi2 distribution with degrees of freedom df. + """Chi2 distribution. + + The Chi2 distribution is defined over positive real numbers using a degrees of + freedom ("df") parameter. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; df, x > 0) = x**(0.5 df - 1) exp(-0.5 x) / Z + Z = 2**(0.5 df) Gamma(0.5 df) + ``` + + where: + + * `df` denotes the degrees of freedom, + * `Z` is the normalization constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function). - The PDF of this distribution is: + The Chi2 distribution is a special case of the Gamma distribution, i.e., - ```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(df/2)Gamma(df/2)), x > 0``` + ```python + Chi2(df) = Gamma(concentration=0.5 * df, rate=0.5) + ``` - Note that the Chi2 distribution is a special case of the Gamma distribution, - with Chi2(df) = Gamma(df/2, 1/2). """ def __init__(self, @@ -46,15 +71,15 @@ class Chi2(gamma.Gamma): Args: df: Floating point tensor, the degrees of freedom of the distribution(s). `df` must contain only positive values. - validate_args: `Boolean`, default `False`. Whether to assert that - `df > 0`, and that `x > 0` in the methods `prob(x)` and `log_prob(x)`. - If `validate_args` is `False` and the inputs are invalid, correct - behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to prepend to all ops created by this distribution. + validate_args: Python `Boolean`, default `False`. 
When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") @@ -65,8 +90,8 @@ class Chi2(gamma.Gamma): with ops.name_scope(name, values=[df]) as ns: self._df = ops.convert_to_tensor(df, name="df") super(Chi2, self).__init__( - alpha=0.5 * self._df, - beta=constant_op.constant(0.5, dtype=self._df.dtype), + concentration=0.5 * self._df, + rate=constant_op.constant(0.5, dtype=self._df.dtype), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) @@ -93,8 +118,9 @@ class Chi2WithAbsDf(Chi2): parameters.pop("self") with ops.name_scope(name, values=[df]) as ns: super(Chi2WithAbsDf, self).__init__( - df=math_ops.floor(math_ops.abs(df, name="abs_df"), - name="floor_abs_df"), + df=math_ops.floor( + math_ops.abs(df, name="abs_df"), + name="floor_abs_df"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py index 2c39c73195..30d8d136ec 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """The Dirichlet distribution class.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -30,189 +31,202 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -_dirichlet_prob_note = """ -Note that the input must be a non-negative tensor with dtype `dtype` and whose -shape can be broadcast with `self.alpha`. For fixed leading dimensions, the -last dimension represents counts for the corresponding Dirichlet distribution -in `self.alpha`. `x` is only legal if it sums up to one. -""" +__all__ = [ + "Dirichlet", +] + + +_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with +dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e., +`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with +`self.batch_shape() + self.event_shape()`.""" class Dirichlet(distribution.Distribution): """Dirichlet distribution. - This distribution is parameterized by a vector `alpha` of concentration - parameters for `k` classes. + The Dirichlet distribution is defined over the + [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive, + length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the + Beta distribution when `k = 2`. - #### Mathematical details + #### Mathematical Details - The Dirichlet is a distribution over the standard n-simplex, where the - standard n-simplex is defined by: - ```{ (x_1, ..., x_n) in R^(n+1) | sum_j x_j = 1 and x_j >= 0 for all j }```. 
-  The distribution has hyperparameters `alpha = (alpha_1,...,alpha_k)`,
-  and probability mass function (prob):
+  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

-  ```prob(x) = 1 / Beta(alpha) * prod_j x_j^(alpha_j - 1)```
+  ```none
+  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
+  ```
+
+  The probability density function (pdf) is,

-  where `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate
-  beta function.
+  ```none
+  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
+  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
+  ```
+
+  where:
+
+  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
+  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
+  * `Z` is the normalization constant aka the [multivariate beta function](
+    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
+    and,
+  * `Gamma` is the [gamma function](
+    https://en.wikipedia.org/wiki/Gamma_function).
+
+  The `concentration` represents mean total counts of class occurrence, i.e.,
+
+  ```none
+  concentration = alpha = mean * total_concentration
+  ```
+  where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
+  representing a mean total count.

-  This class provides methods to create indexed batches of Dirichlet
-  distributions. If the provided `alpha` is rank 2 or higher, for
-  every fixed set of leading dimensions, the last dimension represents one
-  single Dirichlet distribution. When calling distribution
-  functions (e.g. `dist.prob(x)`), `alpha` and `x` are broadcast to the
-  same shape (if possible). In all cases, the last dimension of alpha/x
-  represents single Dirichlet distributions.
+  Distribution parameters are automatically broadcast in all functions; see
+  examples for details.

   #### Examples

   ```python
-  alpha = [1, 2, 3]
+  # Create a single trivariate Dirichlet, with the 3rd class being three times
+  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
+  alpha = [1., 2, 3]
   dist = Dirichlet(alpha)
-  ```

-  Creates a 3-class distribution, with the 3rd class is most likely to be drawn.
-  The distribution functions can be evaluated on x.
+  dist.sample([4, 5])  # shape: [4, 5, 3]

-  ```python
-  # x same shape as alpha.
-  x = [.2, .3, .5]
-  dist.prob(x)  # Shape []
+  # x has one sample, one batch, three classes:
+  x = [.2, .3, .5]   # shape: [3]
+  dist.prob(x)       # shape: []

-  # alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x.
-  x = [[.1, .4, .5], [.2, .3, .5]]
-  dist.prob(x)  # Shape [2]
+  # x has two samples from one batch:
+  x = [[.1, .4, .5],
+       [.2, .3, .5]]
+  dist.prob(x)         # shape: [2]

   # alpha will be broadcast to shape [5, 7, 3] to match x.
-  x = [[...]]  # Shape [5, 7, 3]
-  dist.prob(x)  # Shape [5, 7]
+  x = [[...]]   # shape: [5, 7, 3]
+  dist.prob(x)  # shape: [5, 7]
   ```

-  Creates a 2-batch of 3-class distributions.
-
   ```python
-  alpha = [[1, 2, 3], [4, 5, 6]]  # Shape [2, 3]
+  # Create batch_shape=[2], event_shape=[3]:
+  alpha = [[1., 2, 3],
+           [4, 5, 6]]   # shape: [2, 3]
   dist = Dirichlet(alpha)

-  # x will be broadcast to [[2, 1, 0], [2, 1, 0]] to match alpha.
+  dist.sample([4, 5])  # shape: [4, 5, 2, 3]
+
   x = [.2, .3, .5]
-  dist.prob(x)  # Shape [2]
+  # x will be broadcast as [[.2, .3, .5],
+  #                         [.2, .3, .5]],
+  # thus matching batch_shape + event_shape, [2, 3].
+  dist.prob(x)         # shape: [2]
   ```

   """

   def __init__(self,
-               alpha,
+               concentration,
                validate_args=False,
                allow_nan_stats=True,
                name="Dirichlet"):
     """Initialize a batch of Dirichlet distributions.
Args: - alpha: Positive floating point tensor with shape broadcastable to - `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` - different `k` class Dirichlet distributions. - validate_args: `Boolean`, default `False`. Whether to assert valid values - for parameters `alpha` and `x` in `prob` and `log_prob`. If `False`, - correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to prefix Ops created by this distribution class. - - Examples: - - ```python - # Define 1-batch of 2-class Dirichlet distributions, - # also known as a Beta distribution. - dist = Dirichlet([1.1, 2.0]) - - # Define a 2-batch of 3-class distributions. - dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - ``` - + concentration: Positive floating-point `Tensor` indicating mean number + of class occurrences; aka "alpha". Implies `self.dtype`, and + `self.batch_shape`, `self.event_shape`, i.e., if + `concentration.shape = [N1, N2, ..., Nm, k]` then + `batch_shape = [N1, N2, ..., Nm]` and + `event_shape = [k]`. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. 
""" parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[alpha]) as ns: - alpha = ops.convert_to_tensor(alpha, name="alpha") - with ops.control_dependencies([ - check_ops.assert_positive(alpha), - check_ops.assert_rank_at_least(alpha, 1) - ] if validate_args else []): - self._alpha = array_ops.identity(alpha, name="alpha") - self._alpha_sum = math_ops.reduce_sum(alpha, - reduction_indices=[-1], - keep_dims=False) + with ops.name_scope(name, values=[concentration]) as ns: + self._concentration = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration, name="concentration"), + validate_args) + self._total_concentration = math_ops.reduce_sum(self._concentration, -1) super(Dirichlet, self).__init__( - dtype=self._alpha.dtype, + dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, is_continuous=True, reparameterization_type=distribution.NOT_REPARAMETERIZED, parameters=parameters, - graph_parents=[self._alpha, self._alpha_sum], + graph_parents=[self._concentration, + self._total_concentration], name=ns) @property - def alpha(self): - """Shape parameter.""" - return self._alpha + def concentration(self): + """Concentration parameter; expected counts for that coordinate.""" + return self._concentration @property - def alpha_sum(self): - """Sum of shape parameter.""" - return self._alpha_sum + def total_concentration(self): + """Sum of last dim of concentration parameter.""" + return self._total_concentration def _batch_shape_tensor(self): - return array_ops.shape(self.alpha_sum) + return array_ops.shape(self.total_concentration) def _batch_shape(self): - return self.alpha_sum.get_shape() + return self.total_concentration.get_shape() def _event_shape_tensor(self): - return array_ops.gather(array_ops.shape(self.alpha), - [array_ops.rank(self.alpha) - 1]) + return array_ops.shape(self.concentration)[-1:] def _event_shape(self): - return self.alpha.get_shape().with_rank_at_least(1)[-1:] + return self.concentration.get_shape().with_rank_at_least(1)[-1:] def _sample_n(self, n, seed=None): gamma_sample = random_ops.random_gamma( - [n,], self.alpha, dtype=self.dtype, seed=seed) - return gamma_sample / math_ops.reduce_sum( - gamma_sample, reduction_indices=[-1], keep_dims=True) + shape=[n], + alpha=self.concentration, + dtype=self.dtype, + seed=seed) + return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keep_dims=True) - @distribution_util.AppendDocstring(_dirichlet_prob_note) + @distribution_util.AppendDocstring(_dirichlet_sample_note) def _log_prob(self, x): - x = ops.convert_to_tensor(x, name="x") - x = self._assert_valid_sample(x) - unnorm_prob = (self.alpha - 1.) * math_ops.log(x) - log_prob = math_ops.reduce_sum( - unnorm_prob, reduction_indices=[-1], - keep_dims=False) - special_math_ops.lbeta(self.alpha) - return log_prob - - @distribution_util.AppendDocstring(_dirichlet_prob_note) + return self._log_unnormalized_prob(x) - self._log_normalization() + + @distribution_util.AppendDocstring(_dirichlet_sample_note) def _prob(self, x): return math_ops.exp(self._log_prob(x)) + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + return math_ops.reduce_sum((self.concentration - 1.) 
* math_ops.log(x), -1) + + def _log_normalization(self): + return special_math_ops.lbeta(self.concentration) + def _entropy(self): - entropy = special_math_ops.lbeta(self.alpha) - entropy += math_ops.digamma(self.alpha_sum) * ( - self.alpha_sum - - math_ops.cast(self.event_shape_tensor()[0], self.dtype)) - entropy += -math_ops.reduce_sum( - (self.alpha - 1.) * math_ops.digamma(self.alpha), - reduction_indices=[-1], - keep_dims=False) - return entropy + k = math_ops.cast(self.event_shape_tensor()[0], self.dtype) + return ( + self._log_normalization() + + ((self.total_concentration - k) + * math_ops.digamma(self.total_concentration)) + - math_ops.reduce_sum( + (self.concentration - 1.) * math_ops.digamma(self.concentration), + axis=-1)) def _mean(self): - return self.alpha / self.alpha_sum[..., None] + return self.concentration / self.total_concentration[..., None] def _covariance(self): x = self._variance_scale_term() * self._mean() @@ -227,37 +241,57 @@ class Dirichlet(distribution.Distribution): def _variance_scale_term(self): """Helper to `_covariance` and `_variance` which computes a shared scale.""" - return math_ops.rsqrt(1. + self.alpha_sum[..., None]) + return math_ops.rsqrt(1. + self.total_concentration[..., None]) @distribution_util.AppendDocstring( - """Note that the mode for the Dirichlet distribution is only defined - when `alpha > 1`. This returns the mode when `alpha > 1`, - and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception - will be raised rather than returning `NaN`.""") + """Note: The mode is undefined when any `concentration <= 1`. If + `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If + `self.allow_nan_stats` is `False` an exception is raised when one or more + modes are undefined.""") def _mode(self): - mode = ((self.alpha - 1.) / - (array_ops.expand_dims(self.alpha_sum, dim=-1) - - math_ops.cast(self.event_shape_tensor()[0], self.dtype))) + k = math_ops.cast(self.event_shape_tensor()[0], self.dtype) + mode = (self.concentration - 1.) 
/ (self.total_concentration[..., None] - k) if self.allow_nan_stats: - nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - shape = array_ops.concat([self.batch_shape_tensor(), - self.event_shape_tensor()], 0) + nan = array_ops.fill( + array_ops.shape(mode), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") return array_ops.where( - math_ops.greater(self.alpha, 1.), - mode, - array_ops.fill(shape, nan, name="nan")) - else: - return control_flow_ops.with_dependencies([ - check_ops.assert_less( - array_ops.ones((), dtype=self.dtype), self.alpha, - message="mode not defined for components of alpha <= 1") - ], mode) - - def _assert_valid_sample(self, x): - if not self.validate_args: return x + math_ops.reduce_all(self.concentration > 1., axis=-1), + mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], self.dtype), + self.concentration, + message="Mode undefined when any concentration <= 1"), + ], mode) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of the concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + check_ops.assert_rank_at_least( + concentration, 1, + message="Concentration parameter must have >=1 dimensions."), + check_ops.assert_less( + 1, array_ops.shape(concentration)[-1], + message="Concentration parameter must have event_size >= 2."), + ], concentration) + + def _maybe_assert_valid_sample(self, x): + """Checks the validity of a sample.""" + if not self.validate_args: + return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive(x), + check_ops.assert_positive( + x, + message="samples must be positive"), distribution_util.assert_close( array_ops.ones((), dtype=self.dtype), - math_ops.reduce_sum(x, reduction_indices=[-1])), + math_ops.reduce_sum(x, -1), + message="sample last-dimension must sum to `1`"), ], x) diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py index 84c11d3801..0ba5165759 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The Dirichlet Multinomial distribution class.""" +"""The DirichletMultinomial distribution class.""" from __future__ import absolute_import from __future__ import division @@ -30,52 +30,76 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -_dirichlet_multinomial_prob_note = """ -For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability -that after sampling `n` draws from this Dirichlet Multinomial -distribution, the number of draws falling in class `j` is `n_j`. Note that -different sequences of draws can result in the same counts, thus the -probability includes a combinatorial coefficient. +__all__ = [ + "DirichletMultinomial", +] -Note that input, "counts", must be a non-negative tensor with dtype `dtype` -and whose shape can be broadcast with `self.alpha`. 
For fixed leading
-dimensions, the last dimension represents counts for the corresponding
-Dirichlet Multinomial distribution in `self.alpha`. `counts` is only legal if
-it sums up to `n` and its components are equal to integer values.
-"""
+
+_dirichlet_multinomial_sample_note = """For each batch of counts,
+`value = [n_0, ..., n_{k-1}]`, `P[value]` is the probability that after sampling
+`self.total_count` draws from this Dirichlet-Multinomial distribution, the
+number of draws falling in class `j` is `n_j`. Since this definition is
+[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
+different sequences have the same counts so the probability includes a
+combinatorial coefficient.
+
+Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
+fractional components, and such that
+`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
+with `self.concentration` and `self.total_count`."""


 class DirichletMultinomial(distribution.Distribution):
-  """DirichletMultinomial mixture distribution.
+  """Dirichlet-Multinomial compound distribution.
+
+  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
+  length-`k` `concentration` vector (`k > 1`) and a `total_count` number of
+  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
+  is defined over a (batch of) length-`k` vector `counts` such that
+  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
+  identically the Beta-Binomial distribution when `k = 2`.

-  This distribution is parameterized by a vector `alpha` of concentration
-  parameters for `k` classes and `n`, the counts per each class..
+  #### Mathematical Details

-  #### Mathematical details
+  The Dirichlet-Multinomial is a distribution over `k`-class counts, i.e., a
+  length-`k` vector of non-negative integer `counts = n = [n_0, ..., n_{k-1}]`.
+
+  The probability mass function (pmf) is,
+
+  ```none
+  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
+  Z = Beta(alpha) / N!
+  ```

-  The Dirichlet Multinomial is a distribution over k-class count data, meaning
-  for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
-  probability of these draws being made from the distribution. The distribution
-  has hyperparameters `alpha = (alpha_1,...,alpha_k)`, and probability mass
-  function (pmf):
+  where:

-  ```pmf(counts) = N! / (n_1!...n_k!) * Beta(alpha + c) / Beta(alpha)```
+  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
+  * `total_count = N`, `N` a positive integer,
+  * `N!` is `N` factorial, and,
+  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
+    [multivariate beta function](
+    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
+    and,
+  * `Gamma` is the [gamma function](
+    https://en.wikipedia.org/wiki/Gamma_function).

-  where above `N = sum_j n_j`, `N!` is `N` factorial, and
-  `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate beta
-  function.
+  Dirichlet-Multinomial is a [compound distribution](
+  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
+  samples are generated as follows.

-  This is a mixture distribution in that `M` samples can be produced by:
-    1. Choose class probabilities `p = (p_1,...,p_k) ~ Dir(alpha)`
-    2. Draw integers `m = (n_1,...,n_k) ~ Multinomial(N, p)`
+  1. Choose class probabilities:
+     `probs = [p_0,...,p_{k-1}] ~ Dir(concentration)`
+
+  2.
Draw integers:
+     `counts = [n_0,...,n_{k-1}] ~ Multinomial(total_count, probs)`

-  This class provides methods to create indexed batches of Dirichlet
-  Multinomial distributions. If the provided `alpha` is rank 2 or higher, for
-  every fixed set of leading dimensions, the last dimension represents one
-  single Dirichlet Multinomial distribution. When calling distribution
-  functions (e.g. `dist.prob(counts)`), `alpha` and `counts` are broadcast to
-  the same shape (if possible). In all cases, the last dimension of
-  alpha/counts represents single Dirichlet Multinomial distributions.
+  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
+  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
+  `concentration`, `total_count` and `counts` are broadcast to the same shape.
+  The last dimension of `counts` corresponds to single Dirichlet-Multinomial
+  distributions.
+
+  Distribution parameters are automatically broadcast in all functions; see
+  examples for details.

   #### Examples

@@ -116,116 +140,102 @@ class DirichletMultinomial(distribution.Distribution):

   """

-  # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
+  # TODO(b/27419586) Change docstring for dtype of concentration once int
+  # allowed.
   def __init__(self,
-               n,
-               alpha,
+               total_count,
+               concentration,
                validate_args=False,
                allow_nan_stats=True,
                name="DirichletMultinomial"):
     """Initialize a batch of DirichletMultinomial distributions.

     Args:
-      n: Non-negative floating point tensor, whose dtype is the same as
-        `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
-        Defines this as a batch of `N1 x ... x Nm` different Dirichlet
-        multinomial distributions. Its components should be equal to integer
-        values.
-      alpha: Positive floating point tensor, whose dtype is the same as
-        `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
-        this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
+      total_count: Non-negative floating point tensor, whose dtype is the same
+        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
+        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
+        Dirichlet multinomial distributions. Its components should be equal to
+        integer values.
+      concentration: Positive floating point tensor, whose dtype is the
+        same as `total_count` with shape broadcastable to `[N1,..., Nm, k]`
+        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `k`
+        class Dirichlet
         multinomial distributions.
-      validate_args: `Boolean`, default `False`. Whether to assert valid
-        values for parameters `alpha` and `n`, and `x` in `prob` and
-        `log_prob`. If `False`, correct behavior is not guaranteed.
-      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
-        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
-        batch member. If `True`, batch members with valid parameters leading to
-        undefined statistics will return NaN for this statistic.
-      name: The name to prefix Ops created by this distribution class.
-
-    Examples:
-
-    ```python
-    # Define 1-batch of 2-class Dirichlet multinomial distribution,
-    # also known as a beta-binomial.
-    dist = DirichletMultinomial(2.0, [1.1, 2.0])
-
-    # Define a 2-batch of 3-class distributions.
-    dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-    ```
-
+      validate_args: Python `Boolean`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance.
When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+        result is undefined. When `False`, an exception is raised if one or
+        more of the statistic's batch members are undefined.
+      name: `String` name prefixed to Ops created by this class.
     """
     parameters = locals()
     parameters.pop("self")
-    with ops.name_scope(name, values=[n, alpha]) as ns:
+    with ops.name_scope(name, values=[total_count, concentration]) as ns:
       # Broadcasting works because:
       # * The broadcasting convention is to prepend dimensions of size [1], and
-      #   we use the last dimension for the distribution, wherease
+      #   we use the last dimension for the distribution, whereas
       #   the batch dimensions are the leading dimensions, which forces the
       #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough
-      #   explicitivity.
-      # * All calls involving `counts` eventually require a broadcast between
-      #   `counts` and alpha.
-      self._alpha = self._assert_valid_alpha(alpha, validate_args)
-      self._n = self._assert_valid_n(n, validate_args)
-      self._alpha_sum = math_ops.reduce_sum(
-          self._alpha, reduction_indices=[-1], keep_dims=False)
+      #   explicitness.
+      # * All calls involving `counts` eventually require a broadcast between
+      #   `counts` and concentration.
+      self._total_count = self._maybe_assert_valid_total_count(
+          ops.convert_to_tensor(total_count, name="total_count"),
+          validate_args)
+      self._concentration = self._maybe_assert_valid_concentration(
+          ops.convert_to_tensor(concentration,
+                                name="concentration"),
+          validate_args)
+      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
       super(DirichletMultinomial, self).__init__(
-          dtype=self._alpha.dtype,
-          is_continuous=False,
-          reparameterization_type=distribution.NOT_REPARAMETERIZED,
+          dtype=self._concentration.dtype,
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats,
+          is_continuous=False,
+          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          parameters=parameters,
-          graph_parents=[self._alpha, self._n, self._alpha_sum],
+          graph_parents=[self._total_count,
+                         self._concentration],
          name=ns)

   @property
-  def n(self):
-    """Parameter defining this distribution."""
-    return self._n
+  def total_count(self):
+    """Number of trials used to construct a sample."""
+    return self._total_count

   @property
-  def alpha(self):
-    """Parameter defining this distribution."""
-    return self._alpha
+  def concentration(self):
+    """Concentration parameter; expected prior counts for that coordinate."""
+    return self._concentration

   @property
-  def alpha_sum(self):
-    """Summation of alpha parameter."""
-    return self._alpha_sum
+  def total_concentration(self):
+    """Sum of last dim of concentration parameter."""
+    return self._total_concentration

   def _batch_shape_tensor(self):
-    return array_ops.shape(self.alpha_sum)
+    return array_ops.shape(self.total_concentration)

   def _batch_shape(self):
-    return self.alpha_sum.get_shape()
+    return self.total_concentration.get_shape()

   def _event_shape_tensor(self):
-    return array_ops.reverse_v2(array_ops.shape(self.alpha), [0])[0]
+    return array_ops.shape(self.concentration)[-1:]

   def _event_shape(self):
-    # Event shape depends only on alpha, not "n".
+    # Event shape depends only on concentration, not "total_count".
+ return self.concentration.get_shape().with_rank_at_least(1)[-1:] def _sample_n(self, n, seed=None): - n_draws = math_ops.cast(self.n, dtype=dtypes.int32) - if self.n.get_shape().ndims is not None: - if self.n.get_shape().ndims != 0: - raise NotImplementedError( - "Sample only supported for scalar number of draws.") - elif self.validate_args: - is_scalar = check_ops.assert_rank( - n_draws, 0, - message="Sample only supported for scalar number of draws.") - n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws) + n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32) k = self.event_shape_tensor()[0] unnormalized_logits = array_ops.reshape( math_ops.log(random_ops.random_gamma( shape=[n], - alpha=self.alpha, + alpha=self.concentration, dtype=self.dtype, seed=seed)), shape=[-1, k]) @@ -233,40 +243,41 @@ class DirichletMultinomial(distribution.Distribution): logits=unnormalized_logits, num_samples=n_draws, seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial")) - x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), - reduction_indices=-2) + x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2) final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0) return array_ops.reshape(x, final_shape) - @distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note) + @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note) def _log_prob(self, counts): - counts = self._assert_valid_counts(counts) - ordered_prob = (special_math_ops.lbeta(self.alpha + counts) - - special_math_ops.lbeta(self.alpha)) - log_prob = ordered_prob + distribution_util.log_combinations( - self.n, counts) - return log_prob - - @distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note) + counts = self._maybe_assert_valid_sample(counts) + ordered_prob = ( + special_math_ops.lbeta(self.concentration + counts) + - special_math_ops.lbeta(self.concentration)) + return ordered_prob + distribution_util.log_combinations( + self.total_count, counts) + + @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note) def _prob(self, counts): return math_ops.exp(self._log_prob(counts)) def _mean(self): - return self.n * (self.alpha / self.alpha_sum[..., None]) + return self.total_count * (self.concentration / + self.total_concentration[..., None]) @distribution_util.AppendDocstring( """The covariance for each batch member is defined as the following: - ``` + ```none Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) * (n + alpha_0) / (1 + alpha_0) ``` - where `alpha_0 = sum_j alpha_j`. + where `concentration = alpha` and + `total_concentration = alpha_0 = sum_j alpha_j`. The covariance between elements in a batch is defined as: - ``` + ```none Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 * (n + alpha_0) / (1 + alpha_0) ``` @@ -280,40 +291,55 @@ class DirichletMultinomial(distribution.Distribution): def _variance(self): scale = self._variance_scale_term() x = scale * self._mean() - return x * (self.n * scale - x) + return x * (self.total_count * scale - x) def _variance_scale_term(self): """Helper to `_covariance` and `_variance` which computes a shared scale.""" # We must take care to expand back the last dim whenever we use the - # alpha_sum. - c0 = self.alpha_sum[..., None] - return math_ops.sqrt((1. + c0 / self.n) / (1. + c0)) + # total_concentration. + c0 = self.total_concentration[..., None] + return math_ops.sqrt((1. + c0 / self.total_count) / (1. 
+ c0))

-  def _assert_valid_counts(self, counts):
+  def _maybe_assert_valid_concentration(self, concentration, validate_args):
+    """Checks the validity of the concentration parameter."""
+    if not validate_args:
+      return concentration
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_positive(
+            concentration,
+            message="Concentration parameter must be positive."),
+        check_ops.assert_rank_at_least(
+            concentration, 1,
+            message="Concentration parameter must have >=1 dimensions."),
+        check_ops.assert_less(
+            1, array_ops.shape(concentration)[-1],
+            message="Concentration parameter must have event_size >= 2."),
+    ], concentration)
+
+  def _maybe_assert_valid_total_count(self, total_count, validate_args):
+    if not validate_args:
+      return total_count
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_non_negative(
+            total_count,
+            message="total_count must be non-negative."),
+        distribution_util.assert_integer_form(
+            total_count,
+            message="total_count cannot contain fractional values."),
+    ], total_count)
+
+  def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
-    counts = ops.convert_to_tensor(counts, name="counts")
     if not self.validate_args:
       return counts
-    candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
     return control_flow_ops.with_dependencies([
-        check_ops.assert_non_negative(counts),
+        check_ops.assert_non_negative(
+            counts,
+            message="counts must be non-negative."),
         check_ops.assert_equal(
-            self._n, candidate_n,
-            message="counts do not sum to n"),
-        distribution_util.assert_integer_form(counts)], counts)
-
-  def _assert_valid_alpha(self, alpha, validate_args):
-    alpha = ops.convert_to_tensor(alpha, name="alpha")
-    if not validate_args:
-      return alpha
-    return control_flow_ops.with_dependencies(
-        [check_ops.assert_rank_at_least(alpha, 1),
-         check_ops.assert_positive(alpha)], alpha)
-
-  def _assert_valid_n(self, n, validate_args):
-    n = ops.convert_to_tensor(n, name="n")
-    if not validate_args:
-      return n
-    return control_flow_ops.with_dependencies(
-        [check_ops.assert_non_negative(n),
-         distribution_util.assert_integer_form(n)], n)
+            self.total_count, math_ops.reduce_sum(counts, -1),
+            message="counts last-dimension must sum to `self.total_count`"),
+        distribution_util.assert_integer_form(
+            counts,
+            message="counts cannot contain fractional components."),
+    ], counts)
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
index 79301436c0..23cc63df59 100644
--- a/tensorflow/contrib/distributions/python/ops/exponential.py
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -29,36 +29,64 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops


+__all__ = [
+    "Exponential",
+    "ExponentialWithSoftplusRate",
+]
+
+
 class Exponential(gamma.Gamma):
-  """The Exponential distribution with rate parameter lam.
+  """Exponential distribution.
+
+  The Exponential distribution is parameterized by an event `rate` parameter.
+
+  #### Mathematical Details
+
+  The probability density function (pdf) is,
+
+  ```none
+  pdf(x; lambda, x > 0) = exp(-lambda x) / Z
+  Z = 1 / lambda
+  ```
+
+  where `rate = lambda` and `Z` is the normalizing constant.
+ + The Exponential distribution is a special case of the Gamma distribution, + i.e., + + ```python + Exponential(rate) = Gamma(concentration=1., rate) + ``` - The PDF of this distribution is: + The Exponential distribution uses a `rate` parameter, or "inverse scale", + which can be intuited as, - ```prob(x) = (lam * e^(-lam * x)), x > 0``` + ```none + X ~ Exponential(rate=1) + Y = X / rate + ``` - Note that the Exponential distribution is a special case of the Gamma - distribution, with Exponential(lam) = Gamma(1, lam). """ def __init__(self, - lam, + rate, validate_args=False, allow_nan_stats=True, name="Exponential"): - """Construct Exponential distribution with parameter `lam`. + """Construct Exponential distribution with parameter `rate`. Args: - lam: Floating point tensor, the rate of the distribution(s). - `lam` must contain only positive values. - validate_args: `Boolean`, default `False`. Whether to assert that - `lam > 0`, and that `x > 0` in the methods `prob(x)` and `log_prob(x)`. - If `validate_args` is `False` and the inputs are invalid, correct - behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to prepend to all ops created by this distribution. + rate: Floating point tensor, equivalent to `1 / mean`. Must contain only + positive values. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") @@ -66,30 +94,30 @@ class Exponential(gamma.Gamma): # true in the parent class "Gamma." Therefore, passing # allow_nan_stats=True # through to the parent class results in unnecessary asserts. - with ops.name_scope(name, values=[lam]) as ns: - self._lam = ops.convert_to_tensor(lam, name="lam") + with ops.name_scope(name, values=[rate]) as ns: + self._rate = ops.convert_to_tensor(rate, name="rate") super(Exponential, self).__init__( - alpha=array_ops.ones((), dtype=self._lam.dtype), - beta=self._lam, + concentration=array_ops.ones((), dtype=self._rate.dtype), + rate=self._rate, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=ns) - # While the Gamma distribution is not re-parameterizable, the - # exponential distribution is. + # While the Gamma distribution is not reparameterizable, the exponential + # distribution is. 
self._reparameterization_type = True
     self._parameters = parameters
-    self._graph_parents += [self._lam]
+    self._graph_parents += [self._rate]

   @staticmethod
   def _param_shapes(sample_shape):
-    return {"lam": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
+    return {"rate": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

   @property
-  def lam(self):
-    return self._lam
+  def rate(self):
+    return self._rate

   def _sample_n(self, n, seed=None):
-    shape = array_ops.concat(([n], array_ops.shape(self._lam)), 0)
+    shape = array_ops.concat(([n], array_ops.shape(self._rate)), 0)
     # Sample uniformly-at-random from the open-interval (0, 1).
     sampled = random_ops.random_uniform(
         shape,
@@ -98,22 +126,22 @@ class Exponential(gamma.Gamma):
         maxval=array_ops.ones((), dtype=self.dtype),
         seed=seed,
         dtype=self.dtype)
-    return -math_ops.log(sampled) / self._lam
+    return -math_ops.log(sampled) / self._rate


-class ExponentialWithSoftplusLam(Exponential):
-  """Exponential with softplus transform on `lam`."""
+class ExponentialWithSoftplusRate(Exponential):
+  """Exponential with softplus transform on `rate`."""

   def __init__(self,
-               lam,
+               rate,
                validate_args=False,
                allow_nan_stats=True,
-               name="ExponentialWithSoftplusLam"):
+               name="ExponentialWithSoftplusRate"):
     parameters = locals()
     parameters.pop("self")
-    with ops.name_scope(name, values=[lam]) as ns:
-      super(ExponentialWithSoftplusLam, self).__init__(
-          lam=nn.softplus(lam, name="softplus_lam"),
+    with ops.name_scope(name, values=[rate]) as ns:
+      super(ExponentialWithSoftplusRate, self).__init__(
+          rate=nn.softplus(rate, name="softplus_rate"),
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats,
           name=ns)
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
index d7fc3a86f6..a8e1173d98 100644
--- a/tensorflow/contrib/distributions/python/ops/gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -36,107 +36,143 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops


+__all__ = [
+    "Gamma",
+    "GammaWithSoftplusConcentrationRate",
+]
+
+
 class Gamma(distribution.Distribution):
-  """The `Gamma` distribution with parameter alpha and beta.
+  """Gamma distribution.
+
+  The Gamma distribution is defined over positive real numbers using
+  parameters `concentration` (aka "alpha") and `rate` (aka "beta").
+
+  #### Mathematical Details
+
+  The probability density function (pdf) is,

-  The parameters are the shape and inverse scale parameters alpha, beta.
+  ```none
+  pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
+  Z = Gamma(alpha) beta**alpha
+  ```
+
+  where:
+
+  * `concentration = alpha`, `alpha > 0`,
+  * `rate = beta`, `beta > 0`,
+  * `Z` is the normalizing constant, and,
+  * `Gamma` is the [gamma function](
+    https://en.wikipedia.org/wiki/Gamma_function).

-  The PDF of this distribution is:
+  The cumulative distribution function (cdf) is,
+
+  ```none
+  cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
+  ```

-  ```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0```
+  where `GammaInc` is the [lower incomplete Gamma function](
+    https://en.wikipedia.org/wiki/Incomplete_gamma_function).
- and the CDF of this distribution is: + The parameters can be intuited via their relationship to mean and stddev, - ```cdf(x) = GammaInc(alpha, beta * x) / Gamma(alpha), x > 0``` + ```none + concentration = alpha = (mean / stddev)**2 + rate = beta = mean / stddev**2 = concentration / mean + ``` - where GammaInc is the incomplete lower Gamma function. + Distribution parameters are automatically broadcast in all functions; see + examples for details. - WARNING: This distribution may draw 0-valued samples for small alpha values. - See the note on `tf.random_gamma`. + WARNING: This distribution may draw 0-valued samples for small `concentration` + values. See note in `tf.random_gamma` docstring. - Examples: + #### Examples ```python - dist = Gamma(alpha=3.0, beta=2.0) - dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0]) + dist = Gamma(concentration=3.0, rate=2.0) + dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ def __init__(self, - alpha, - beta, + concentration, + rate, validate_args=False, allow_nan_stats=True, name="Gamma"): - """Construct Gamma distributions with parameters `alpha` and `beta`. + """Construct Gamma with `concentration` and `rate` parameters. - The parameters `alpha` and `beta` must be shaped in a way that supports - broadcasting (e.g. `alpha + beta` is a valid operation). + The parameters `concentration` and `rate` must be shaped in a way that + supports broadcasting (e.g. `concentration + rate` is a valid operation). Args: - alpha: Floating point tensor, the shape params of the - distribution(s). - alpha must contain only positive values. - beta: Floating point tensor, the inverse scale params of the - distribution(s). - beta must contain only positive values. - validate_args: `Boolean`, default `False`. Whether to assert that - `a > 0`, `b > 0`, and that `x > 0` in the methods `prob(x)` and - `log_prob(x)`. If `validate_args` is `False` and the inputs are - invalid, correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: The name to prepend to all ops created by this distribution. + concentration: Floating point tensor, the concentration params of the + distribution(s). Must contain only positive values. + rate: Floating point tensor, the inverse scale params of the + distribution(s). Must contain only positive values. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: - TypeError: if `alpha` and `beta` are different dtypes. + TypeError: if `concentration` and `rate` are different dtypes. 
""" parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[alpha, beta]) as ns: + with ops.name_scope(name, values=[concentration, rate]) as ns: with ops.control_dependencies([ - check_ops.assert_positive(alpha), - check_ops.assert_positive(beta), + check_ops.assert_positive(concentration), + check_ops.assert_positive(rate), ] if validate_args else []): - self._alpha = array_ops.identity(alpha, name="alpha") - self._beta = array_ops.identity(beta, name="beta") - contrib_tensor_util.assert_same_float_dtype((self._alpha, self._beta)) + self._concentration = array_ops.identity( + concentration, name="concentration") + self._rate = array_ops.identity(rate, name="rate") + contrib_tensor_util.assert_same_float_dtype( + [self._concentration, self._rate]) super(Gamma, self).__init__( - dtype=self._alpha.dtype, + dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, is_continuous=True, reparameterization_type=distribution.NOT_REPARAMETERIZED, parameters=parameters, - graph_parents=[self._alpha, self._beta], + graph_parents=[self._concentration, + self._rate], name=ns) @staticmethod def _param_shapes(sample_shape): return dict( - zip(("alpha", "beta"), ([ops.convert_to_tensor( + zip(("concentration", "rate"), ([ops.convert_to_tensor( sample_shape, dtype=dtypes.int32)] * 2))) @property - def alpha(self): - """Shape parameter.""" - return self._alpha + def concentration(self): + """Concentration parameter.""" + return self._concentration @property - def beta(self): - """Inverse scale parameter.""" - return self._beta + def rate(self): + """Rate parameter.""" + return self._rate def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.alpha), array_ops.shape(self.beta)) + array_ops.shape(self.concentration), + array_ops.shape(self.rate)) def _batch_shape(self): return array_ops.broadcast_static_shape( - self.alpha.get_shape(), self.beta.get_shape()) + self.concentration.get_shape(), + self.rate.get_shape()) def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) @@ -144,99 +180,101 @@ class Gamma(distribution.Distribution): def _event_shape(self): return tensor_shape.scalar() + @distribution_util.AppendDocstring( + """Note: See `tf.random_gamma` docstring for sampling details and + caveats.""") def _sample_n(self, n, seed=None): - """See the documentation for tf.random_gamma for more details.""" - return random_ops.random_gamma([n], - self.alpha, - beta=self.beta, - dtype=self.dtype, - seed=seed) + return random_ops.random_gamma( + shape=[n], + alpha=self.concentration, + beta=self.rate, + dtype=self.dtype, + seed=seed) def _log_prob(self, x): - x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.validate_args else [], x) - contrib_tensor_util.assert_same_float_dtype(tensors=[x], - dtype=self.dtype) - return (self.alpha * math_ops.log(self.beta) + - (self.alpha - 1.) * math_ops.log(x) - - self.beta * x - - math_ops.lgamma(self.alpha)) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): return math_ops.exp(self._log_prob(x)) def _log_cdf(self, x): - x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.validate_args else [], x) - contrib_tensor_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype) - # Note that igamma returns the regularized incomplete gamma function, - # which is what we want for the CDF. 
-    return math_ops.log(math_ops.igamma(self.alpha, self.beta * x))
+    return math_ops.log(self._cdf(x))

  def _cdf(self, x):
-    return math_ops.igamma(self.alpha, self.beta * x)
+    x = self._maybe_assert_valid_sample(x)
+    # Note that igamma returns the regularized incomplete gamma function,
+    # which is what we want for the CDF.
+    return math_ops.igamma(self.concentration, self.rate * x)

-  @distribution_util.AppendDocstring(
-      """This is defined to be
+  def _log_unnormalized_prob(self, x):
+    x = self._maybe_assert_valid_sample(x)
+    return (self.concentration - 1.) * math_ops.log(x) - self.rate * x

-      ```
-      entropy = alpha - log(beta) + log(Gamma(alpha))
-                + (1-alpha)digamma(alpha)
-      ```
+  def _log_normalization(self):
+    return (math_ops.lgamma(self.concentration)
+            - self.concentration * math_ops.log(self.rate))

-      where digamma(alpha) is the digamma function.
-      """)
  def _entropy(self):
-    return (self.alpha -
-            math_ops.log(self.beta) +
-            math_ops.lgamma(self.alpha) +
-            (1. - self.alpha) * math_ops.digamma(self.alpha))
+    return (self.concentration
+            - math_ops.log(self.rate)
+            + math_ops.lgamma(self.concentration)
+            + ((1. - self.concentration) *
+               math_ops.digamma(self.concentration)))

  def _mean(self):
-    return self.alpha / self.beta
+    return self.concentration / self.rate

  def _variance(self):
-    return self.alpha / math_ops.square(self.beta)
+    return self.concentration / math_ops.square(self.rate)

  def _stddev(self):
-    return math_ops.sqrt(self.alpha) / self.beta
+    return math_ops.sqrt(self.concentration) / self.rate

  @distribution_util.AppendDocstring(
-      """The mode of a gamma distribution is `(alpha - 1) / beta` when
-      `alpha > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is `False`,
+      """The mode of a gamma distribution is `(concentration - 1) / rate` when
+      `concentration > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is
      `False`, an exception will be raised rather than returning `NaN`.""")
  def _mode(self):
-    mode = (self.alpha - 1.) / self.beta
+    mode = (self.concentration - 1.)
/ self.rate if self.allow_nan_stats: - nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return array_ops.where( - self.alpha >= 1., - mode, - array_ops.fill(self.batch_shape_tensor(), nan, name="nan")) + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") + return array_ops.where(self.concentration > 1., mode, nan) else: return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones((), self.dtype), - self.alpha, - message="mode not defined for components of alpha <= 1"), + array_ops.ones([], self.dtype), + self.concentration, + message="mode not defined when any concentration <= 1"), ], mode) + def _maybe_assert_valid_sample(self, x): + contrib_tensor_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype) + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_positive(x), + ], x) + -class GammaWithSoftplusAlphaBeta(Gamma): - """Gamma with softplus transform on `alpha` and `beta`.""" +class GammaWithSoftplusConcentrationRate(Gamma): + """`Gamma` with softplus of `concentration` and `rate`.""" def __init__(self, - alpha, - beta, + concentration, + rate, validate_args=False, allow_nan_stats=True, - name="GammaWithSoftplusAlphaBeta"): + name="GammaWithSoftplusConcentrationRate"): parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[alpha, beta]) as ns: - super(GammaWithSoftplusAlphaBeta, self).__init__( - alpha=nn.softplus(alpha, name="softplus_alpha"), - beta=nn.softplus(beta, name="softplus_beta"), + with ops.name_scope(name, values=[concentration, rate]) as ns: + super(GammaWithSoftplusConcentrationRate, self).__init__( + concentration=nn.softplus(concentration, + name="softplus_concentration"), + rate=nn.softplus(rate, name="softplus_rate"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) @@ -256,15 +294,16 @@ def _kl_gamma_gamma(g0, g1, name=None): Returns: kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1). 
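The closed form computed just below can be cross-checked by Monte Carlo, since `KL(g0 || g1) = E_g0[log g0(X) - log g1(X)]`. A sketch (not part of this diff), assuming SciPy, with illustrative parameters:

```python
import numpy as np
from scipy import special, stats

a0, b0 = 3.0, 2.0  # g0: concentration, rate
a1, b1 = 4.0, 3.0  # g1: concentration, rate

# Same closed form as the return expression below, in NumPy terms.
kl = ((a0 - a1) * special.digamma(a0)
      + special.gammaln(a1) - special.gammaln(a0)
      + a1 * (np.log(b0) - np.log(b1))
      + a0 * (b1 / b0 - 1.))

# Monte Carlo estimate of E_g0[log g0(X) - log g1(X)].
g0 = stats.gamma(a=a0, scale=1. / b0)
g1 = stats.gamma(a=a1, scale=1. / b1)
x = g0.rvs(size=200000, random_state=0)
np.testing.assert_allclose(kl, np.mean(g0.logpdf(x) - g1.logpdf(x)), atol=0.02)
```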
""" - with ops.name_scope(name, "kl_gamma_gamma", - values=[g0.alpha, g0.beta, g1.alpha, g1.beta]): + with ops.name_scope(name, "kl_gamma_gamma", values=[ + g0.concentration, g0.rate, g1.concentration, g1.rate]): # Result from: # http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps # For derivation see: # http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long - return ((g0.alpha - g1.alpha) * math_ops.digamma(g0.alpha) - + math_ops.lgamma(g1.alpha) - - math_ops.lgamma(g0.alpha) - + g1.alpha * math_ops.log(g0.beta) - - g1.alpha * math_ops.log(g1.beta) - + g0.alpha * (g1.beta / g0.beta - 1.)) + return (((g0.concentration - g1.concentration) + * math_ops.digamma(g0.concentration)) + + math_ops.lgamma(g1.concentration) + - math_ops.lgamma(g0.concentration) + + g1.concentration * math_ops.log(g0.rate) + - g1.concentration * math_ops.log(g1.rate) + + g0.concentration * (g1.rate / g0.rate - 1.)) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 9fa6e55fa8..3bfc169c6b 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -34,102 +35,144 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops +__all__ = [ + "InverseGamma", + "InverseGammaWithSoftplusConcentrationRate", +] + + class InverseGamma(distribution.Distribution): - """The `InverseGamma` distribution with parameter alpha and beta. + """InverseGamma distribution. - The parameters are the shape and inverse scale parameters alpha, beta. + The `InverseGamma` distribution is defined over positive real numbers using + parameters `concentration` (aka "alpha") and `rate` (aka "beta"). - The PDF of this distribution is: + #### Mathematical Details - ```pdf(x) = (beta^alpha)/Gamma(alpha)(x^(-alpha-1))e^(-beta/x), x > 0``` + The probability density function (pdf) is, - and the CDF of this distribution is: + ```none + pdf(x; alpha, beta, x > 0) = x**(-alpha - 1) exp(-beta / x) / Z + Z = Gamma(alpha) beta**-alpha + ``` - ```cdf(x) = GammaInc(alpha, beta / x) / Gamma(alpha), x > 0``` + where: - where GammaInc is the upper incomplete Gamma function. + * `concentration = alpha`, + * `rate = beta`, + * `Z` is the normalizing constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function). - Examples: + The cumulative density function (cdf) is, + + ```none + cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta / x) / Gamma(alpha) + ``` + + where `GammaInc` is the [upper incomplete Gamma function]( + https://en.wikipedia.org/wiki/Incomplete_gamma_function). + + The parameters can be intuited via their relationship to mean and stddev, + + ```none + concentration = alpha = (mean / stddev)**2 + rate = beta = mean / stddev**2 + ``` + + Distribution parameters are automatically broadcast in all functions; see + examples for details. + + WARNING: This distribution may draw 0-valued samples for small concentration + values. 
See note in `tf.random_gamma` docstring.
+
+  #### Examples
+
  ```python
-  dist = InverseGamma(alpha=3.0, beta=2.0)
-  dist2 = InverseGamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
+  dist = InverseGamma(concentration=3.0, rate=2.0)
+  dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
  ```
  """

  def __init__(self,
-               alpha,
-               beta,
+               concentration,
+               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="InverseGamma"):
-    """Construct InverseGamma distributions with parameters `alpha` and `beta`.
+    """Construct InverseGamma with `concentration` and `rate` parameters.

-    The parameters `alpha` and `beta` must be shaped in a way that supports
-    broadcasting (e.g. `alpha + beta` is a valid operation).
+    The parameters `concentration` and `rate` must be shaped in a way that
+    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
-      alpha: Floating point tensor, the shape params of the
-        distribution(s).
-        alpha must contain only positive values.
-      beta: Floating point tensor, the scale params of the distribution(s).
-        beta must contain only positive values.
-      validate_args: `Boolean`, default `False`. Whether to assert that
-        `a > 0`, `b > 0`, and that `x > 0` in the methods `prob(x)` and
-        `log_prob(x)`. If `validate_args` is `False` and the inputs are
-        invalid, correct behavior is not guaranteed.
-      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
-        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
-        batch member. If `True`, batch members with valid parameters leading to
-        undefined statistics will return NaN for this statistic.
-      name: The name to prepend to all ops created by this distribution.
+      concentration: Floating point tensor, the concentration params of the
+        distribution(s). Must contain only positive values.
+      rate: Floating point tensor, the scale params of the
+        distribution(s). Must contain only positive values.
+      validate_args: Python `Boolean`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+        result is undefined. When `False`, an exception is raised if one or
+        more of the statistic's batch members are undefined.
+      name: `String` name prefixed to Ops created by this class.
+
    Raises:
-      TypeError: if `alpha` and `beta` are different dtypes.
+      TypeError: if `concentration` and `rate` are different dtypes.
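As the refactored `_sample_n` below makes explicit, sampling uses the fact that if `G ~ Gamma(concentration, rate)` then `1 / G ~ InverseGamma(concentration, rate)`. A sketch (not part of this diff), assuming NumPy and SciPy; note that SciPy's `invgamma` uses `scale` for what this class calls `rate`:

```python
import numpy as np
from scipy import stats

concentration, rate = 3.0, 2.0
rng = np.random.RandomState(0)

# Reciprocal of Gamma draws; NumPy's `scale` is 1 / rate.
samples = 1. / rng.gamma(shape=concentration, scale=1. / rate, size=200000)

ref = stats.invgamma(a=concentration, scale=rate)
np.testing.assert_allclose(samples.mean(), ref.mean(), rtol=0.05)
```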
""" parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[alpha, beta]) as ns: + with ops.name_scope(name, values=[concentration, rate]) as ns: with ops.control_dependencies([ - check_ops.assert_positive(alpha), - check_ops.assert_positive(beta), + check_ops.assert_positive(concentration), + check_ops.assert_positive(rate), ] if validate_args else []): - self._alpha = array_ops.identity(alpha, name="alpha") - self._beta = array_ops.identity(beta, name="beta") + self._concentration = array_ops.identity( + concentration, name="concentration") + self._rate = array_ops.identity(rate, name="rate") + contrib_tensor_util.assert_same_float_dtype( + [self._concentration, self._rate]) super(InverseGamma, self).__init__( - dtype=self._alpha.dtype, + dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, is_continuous=True, reparameterization_type=distribution.NOT_REPARAMETERIZED, parameters=parameters, - graph_parents=[self._alpha, self._beta], + graph_parents=[self._concentration, + self._rate], name=ns) @staticmethod def _param_shapes(sample_shape): return dict( - zip(("alpha", "beta"), ([ops.convert_to_tensor( + zip(("concentration", "rate"), ([ops.convert_to_tensor( sample_shape, dtype=dtypes.int32)] * 2))) @property - def alpha(self): - """Shape parameter.""" - return self._alpha + def concentration(self): + """Concentration parameter.""" + return self._concentration @property - def beta(self): - """Scale parameter.""" - return self._beta + def rate(self): + """Rate parameter.""" + return self._rate def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.alpha), array_ops.shape(self.beta)) + array_ops.shape(self.concentration), + array_ops.shape(self.rate)) def _batch_shape(self): return array_ops.broadcast_static_shape( - self.alpha.get_shape(), self.beta.get_shape()) + self.concentration.get_shape(), + self.rate.get_shape()) def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) @@ -137,17 +180,19 @@ class InverseGamma(distribution.Distribution): def _event_shape(self): return tensor_shape.scalar() + @distribution_util.AppendDocstring( + """Note: See `tf.random_gamma` docstring for sampling details and + caveats.""") def _sample_n(self, n, seed=None): - """See the documentation for tf.random_gamma for more details.""" - return 1. / random_ops.random_gamma([n], self.alpha, beta=self.beta, - dtype=self.dtype, seed=seed) + return 1. / random_ops.random_gamma( + shape=[n], + alpha=self.concentration, + beta=self.rate, + dtype=self.dtype, + seed=seed) def _log_prob(self, x): - x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.validate_args else [], x) - return (self.alpha * math_ops.log(self.beta) - - math_ops.lgamma(self.alpha) - - (self.alpha + 1.) * math_ops.log(x) - self.beta / x) + return self._log_unnormalized_prob(x) - self._log_normalization() def _prob(self, x): return math_ops.exp(self._log_prob(x)) @@ -156,84 +201,100 @@ class InverseGamma(distribution.Distribution): return math_ops.log(self._cdf(x)) def _cdf(self, x): - x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.validate_args else [], x) + x = self._maybe_assert_valid_sample(x) # Note that igammac returns the upper regularized incomplete gamma # function Q(a, x), which is what we want for the CDF. 
- return math_ops.igammac(self.alpha, self.beta / x) + return math_ops.igammac(self.concentration, self.rate / x) - @distribution_util.AppendDocstring( - """This is defined to be + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + return -(1. + self.concentration) * math_ops.log(x) - self.rate / x - ``` - entropy = alpha - log(beta) + log(Gamma(alpha)) - + (1-alpha)digamma(alpha) - ``` + def _log_normalization(self): + return (math_ops.lgamma(self.concentration) + - self.concentration * math_ops.log(self.rate)) - where digamma(alpha) is the digamma function.""") def _entropy(self): - return (self.alpha + - math_ops.log(self.beta) + - math_ops.lgamma(self.alpha) - - (1. + self.alpha) * math_ops.digamma(self.alpha)) + return (self.concentration + + math_ops.log(self.rate) + + math_ops.lgamma(self.concentration) + - ((1. + self.concentration) * + math_ops.digamma(self.concentration))) @distribution_util.AppendDocstring( - """The mean of an inverse gamma distribution is `beta / (alpha - 1)`, - when `alpha > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is - `False`, an exception will be raised rather than returning `NaN`""") + """The mean of an inverse gamma distribution is + `rate / (concentration - 1)`, when `concentration > 1`, and `NaN` + otherwise. If `self.allow_nan_stats` is `False`, an exception will be + raised rather than returning `NaN`""") def _mean(self): - mean = self.beta / (self.alpha - 1.) + mean = self.rate / (self.concentration - 1.) if self.allow_nan_stats: - nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return array_ops.where( - self.alpha > 1., mean, - array_ops.fill(self.batch_shape_tensor(), nan, name="nan")) + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") + return array_ops.where(self.concentration > 1., mean, nan) else: return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones((), self.dtype), self.alpha, - message="mean not defined for components of self.alpha <= 1"), + array_ops.ones([], self.dtype), self.concentration, + message="mean undefined when any concentration <= 1"), ], mean) @distribution_util.AppendDocstring( - """Variance for inverse gamma is defined only for `alpha > 2`. If + """Variance for inverse gamma is defined only for `concentration > 2`. If `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`.""") def _variance(self): - var = (math_ops.square(self.beta) / - (math_ops.square(self.alpha - 1.) * (self.alpha - 2.))) + var = (math_ops.square(self.rate) + / math_ops.square(self.concentration - 1.) 
+ / (self.concentration - 2.)) if self.allow_nan_stats: - nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return array_ops.where( - self.alpha > 2., var, - array_ops.fill(self.batch_shape_tensor(), nan, name="nan")) + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") + return array_ops.where(self.concentration > 2., var, nan) else: return control_flow_ops.with_dependencies([ check_ops.assert_less( - constant_op.constant(2., dtype=self.dtype), self.alpha, - message="variance not defined for components of alpha <= 2"), + constant_op.constant(2., dtype=self.dtype), + self.concentration, + message="variance undefined when any concentration <= 2"), ], var) + @distribution_util.AppendDocstring( + """The mode of an inverse gamma distribution is `rate / (concentration + + 1)`.""") def _mode(self): - """The mode of an inverse gamma distribution is `beta / (alpha + 1)`.""" - return self.beta / (self.alpha + 1.) + return self.rate / (1. + self.concentration) + + def _maybe_assert_valid_sample(self, x): + contrib_tensor_util.assert_same_float_dtype( + tensors=[x], dtype=self.dtype) + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_positive(x), + ], x) -class InverseGammaWithSoftplusAlphaBeta(InverseGamma): - """Inverse Gamma with softplus applied to `alpha` and `beta`.""" +class InverseGammaWithSoftplusConcentrationRate(InverseGamma): + """`InverseGamma` with softplus of `concentration` and `rate`.""" def __init__(self, - alpha, - beta, + concentration, + rate, validate_args=False, allow_nan_stats=True, - name="InverseGammaWithSoftplusAlphaBeta"): + name="InverseGammaWithSoftplusConcentrationRate"): parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[alpha, beta]) as ns: - super(InverseGammaWithSoftplusAlphaBeta, self).__init__( - alpha=nn.softplus(alpha, name="softplus_alpha"), - beta=nn.softplus(beta, name="softplus_gamma"), + with ops.name_scope(name, values=[concentration, rate]) as ns: + super(InverseGammaWithSoftplusConcentrationRate, self).__init__( + concentration=nn.softplus(concentration, + name="softplus_concentration"), + rate=nn.softplus(rate, name="softplus_rate"), validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=ns) diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index cabb99d106..fa516141fb 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -30,9 +30,14 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -_poisson_prob_note = """ -Note thet the input value must be a non-negative floating point tensor with -dtype `dtype` and whose shape can be broadcast with `self.lam`. `x` is only +__all__ = [ + "Poisson", +] + + +_poisson_sample_note = """ +Note that the input value must be a non-negative floating point tensor with +dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only legal if it is non-negative and its components are equal to integer values. """ @@ -40,63 +45,67 @@ legal if it is non-negative and its components are equal to integer values. class Poisson(distribution.Distribution): """Poisson distribution. - The Poisson distribution is parameterized by `lam`, the rate parameter. + The Poisson distribution is parameterized by an event `rate` parameter. 
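The `_log_unnormalized_prob` / `_log_normalization` split introduced further below amounts to evaluating the pmf in log space, `log pmf(k) = k log(rate) - lgamma(k + 1) - rate`, which avoids overflow in `k!`. A sketch (not part of this diff), assuming SciPy:

```python
import numpy as np
from scipy import special, stats

rate = 3.0
k = np.arange(10.)

# log pmf = log_unnormalized_prob - log_normalization, with log Z = rate.
log_prob = k * np.log(rate) - special.gammaln(k + 1.) - rate
np.testing.assert_allclose(np.exp(log_prob), stats.poisson(mu=rate).pmf(k))
```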
- The pmf of this distribution is: + #### Mathematical Details - ``` + The probability mass function (pmf) is, - pmf(k) = e^(-lam) * lam^k / k!, k >= 0 + ```none + pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z + Z = exp(lambda). ``` + where `rate = lambda` and `Z` is the normalizing constant. + """ def __init__(self, - lam, + rate, validate_args=False, allow_nan_stats=True, name="Poisson"): - """Construct Poisson distributions. + """Initialize a batch of Poisson distributions. Args: - lam: Floating point tensor, the rate parameter of the - distribution(s). `lam` must be positive. - validate_args: `Boolean`, default `False`. Whether to assert that - `lam > 0` as well as inputs to `prob` computations are non-negative - integers. If validate_args is `False`, then `prob` computations might - return `NaN`, but can be evaluated at any real value. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g. mean/mode/etc...) is undefined for any - batch member. If `True`, batch members with valid parameters leading to - undefined statistics will return NaN for this statistic. - name: A name for this distribution. + rate: Floating point tensor, the rate parameter of the + distribution(s). `rate` must be positive. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. 
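The `_cdf` method further below relies on the classical identity relating the Poisson CDF to the regularized upper incomplete gamma function: `P(X <= x) = Q(floor(x) + 1, rate)`. A cross-check (not part of this diff), assuming SciPy and illustrative values:

```python
import numpy as np
from scipy import special, stats

rate = 3.0
x = np.array([0., 1., 2.5, 7.])

# floor(x + 1) == floor(x) + 1 for all real x, matching the `_cdf` code.
cdf = special.gammaincc(np.floor(x + 1.), rate)
np.testing.assert_allclose(cdf, stats.poisson(mu=rate).cdf(x))
```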
""" parameters = locals() parameters.pop("self") - with ops.name_scope(name, values=[lam]) as ns: - with ops.control_dependencies([check_ops.assert_positive(lam)] if + with ops.name_scope(name, values=[rate]) as ns: + with ops.control_dependencies([check_ops.assert_positive(rate)] if validate_args else []): - self._lam = array_ops.identity(lam, name="lam") + self._rate = array_ops.identity(rate, name="rate") super(Poisson, self).__init__( - dtype=self._lam.dtype, + dtype=self._rate.dtype, is_continuous=False, reparameterization_type=distribution.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, - graph_parents=[self._lam], + graph_parents=[self._rate], name=ns) @property - def lam(self): + def rate(self): """Rate parameter.""" - return self._lam + return self._rate def _batch_shape_tensor(self): - return array_ops.shape(self.lam) + return array_ops.shape(self.rate) def _batch_shape(self): - return self.lam.get_shape() + return self.rate.get_shape() def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) @@ -104,40 +113,47 @@ class Poisson(distribution.Distribution): def _event_shape(self): return tensor_shape.scalar() - @distribution_util.AppendDocstring(_poisson_prob_note) + @distribution_util.AppendDocstring(_poisson_sample_note) def _log_prob(self, x): - x = self._assert_valid_sample(x, check_integer=True) - return x * math_ops.log(self.lam) - self.lam - math_ops.lgamma(x + 1) + return self._log_unnormalized_prob(x) - self._log_normalization() - @distribution_util.AppendDocstring(_poisson_prob_note) + @distribution_util.AppendDocstring(_poisson_sample_note) def _prob(self, x): return math_ops.exp(self._log_prob(x)) + @distribution_util.AppendDocstring(_poisson_sample_note) def _log_cdf(self, x): return math_ops.log(self.cdf(x)) + @distribution_util.AppendDocstring(_poisson_sample_note) def _cdf(self, x): x = self._assert_valid_sample(x, check_integer=False) - return math_ops.igammac(math_ops.floor(x + 1), self.lam) + return math_ops.igammac(math_ops.floor(x + 1), self.rate) + + def _log_normalization(self): + return self.rate + + def _log_unnormalized_prob(self, x): + x = self._assert_valid_sample(x, check_integer=True) + return x * math_ops.log(self.rate) - math_ops.lgamma(x + 1) def _mean(self): - return array_ops.identity(self.lam) + return array_ops.identity(self.rate) def _variance(self): - return array_ops.identity(self.lam) + return array_ops.identity(self.rate) @distribution_util.AppendDocstring( - """Note that when `lam` is an integer, there are actually two modes. - Namely, `lam` and `lam - 1` are both modes. Here we return - only the larger of the two modes.""") + """Note: when `rate` is an integer, there are actually two modes: `rate` + and `rate - 1`. 
In this case we return the larger, i.e., `rate`.""") def _mode(self): - return math_ops.floor(self.lam) + return math_ops.floor(self.rate) def _assert_valid_sample(self, x, check_integer=True): - if not self.validate_args: return x - with ops.name_scope("check_x", values=[x]): - dependencies = [check_ops.assert_non_negative(x)] - if check_integer: - dependencies += [distribution_util.assert_integer_form( - x, message="x has non-integer components.")] - return control_flow_ops.with_dependencies(dependencies, x) + if not self.validate_args: + return x + dependencies = [check_ops.assert_non_negative(x)] + if check_integer: + dependencies += [distribution_util.assert_integer_form( + x, message="x has non-integer components.")] + return control_flow_ops.with_dependencies(dependencies, x) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 24e48f0a8d..5f5c74184a 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -38,6 +38,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +__all__ = [ + "WishartCholesky", + "WishartFull", +] + + class _WishartOperatorPD(distribution.Distribution): """The matrix Wishart distribution on positive definite matrices. @@ -45,22 +51,22 @@ class _WishartOperatorPD(distribution.Distribution): an instance of `OperatorPDBase`, which provides matrix-free access to a symmetric positive definite operator, which defines the scale matrix. - #### Mathematical details. + #### Mathematical Details - The PDF of this distribution is, + The probability density function (pdf) is, + ```none + pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z + Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df) ``` - f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df) - ``` - - where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd, - `k x k` matrix, and the normalizing constant `B(scale, df)` is given by: - ``` - B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df) - ``` + where: - where `Gamma_k` is the multivariate Gamma function. + * `df >= k` denotes the degrees of freedom, + * `scale` is a symmetric, positive definite, `k x k` matrix, + * `Z` is the normalizing constant, and, + * `Gamma_k` is the [multivariate Gamma function]( + https://en.wikipedia.org/wiki/Multivariate_gamma_function). #### Examples @@ -86,19 +92,21 @@ class _WishartOperatorPD(distribution.Distribution): Cholesky factored matrix. Example `log_prob` input takes a Cholesky and `sample_n` returns a Cholesky when `cholesky_input_output_matrices=True`. - validate_args: `Boolean`, default `False`. Whether to validate input with - asserts. If `validate_args` is `False`, and the inputs are invalid, - correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g., mean, mode) is undefined for any batch - member. If True, batch members with valid parameters leading to - undefined statistics will return `NaN` for this statistic. - name: The name to give Ops created by the initializer. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. 
When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. Raises: TypeError: if scale is not floating-type TypeError: if scale.dtype != df.dtype - ValueError: if df < k, where scale operator event shape is `(k, k)` + ValueError: if df < k, where scale operator event shape is + `(k, k)` """ parameters = locals() parameters.pop("self") @@ -111,7 +119,9 @@ class _WishartOperatorPD(distribution.Distribution): scale_operator_pd.dtype) self._scale_operator_pd = scale_operator_pd self._df = ops.convert_to_tensor( - df, dtype=scale_operator_pd.dtype, name="df") + df, + dtype=scale_operator_pd.dtype, + name="df") contrib_tensor_util.assert_same_float_dtype( (self._df, self._scale_operator_pd)) if (self._scale_operator_pd.get_shape().ndims is None or @@ -127,19 +137,22 @@ class _WishartOperatorPD(distribution.Distribution): dim_val = tensor_util.constant_value(self._dimension) if df_val is not None and dim_val is not None: df_val = np.asarray(df_val) - if not df_val.shape: df_val = (df_val,) + if not df_val.shape: + df_val = [df_val] if any(df_val < dim_val): raise ValueError( - "Degrees of freedom (df = %s) cannot be less than dimension of " - "scale matrix (scale.dimension = %s)" + "Degrees of freedom (df = %s) cannot be less than " + "dimension of scale matrix (scale.dimension = %s)" % (df_val, dim_val)) elif validate_args: assertions = check_ops.assert_less_equal( self._dimension, self._df, - message=("Degrees of freedom (df = %s) cannot be less than " - "dimension of scale matrix (scale.dimension = %s)" % + message=("Degrees of freedom (df = %s) cannot be " + "less than dimension of scale matrix " + "(scale.dimension = %s)" % (self._dimension, self._df))) - self._df = control_flow_ops.with_dependencies([assertions], self._df) + self._df = control_flow_ops.with_dependencies( + [assertions], self._df) super(_WishartOperatorPD, self).__init__( dtype=self._scale_operator_pd.dtype, validate_args=validate_args, @@ -321,7 +334,7 @@ class _WishartOperatorPD(distribution.Distribution): # Complexity: O(nbk^2) log_prob = ((self.df - self.dimension - 1.) * half_log_det_x - 0.5 * trace_scale_inv_x - - self.log_normalizing_constant()) + self.log_normalization()) # Set shape hints. # Try to merge what we know from the input then what we know from the @@ -349,7 +362,8 @@ class _WishartOperatorPD(distribution.Distribution): def _mean(self): if self.cholesky_input_output_matrices: - return math_ops.sqrt(self.df) * self.scale_operator_pd.sqrt_to_dense() + return (math_ops.sqrt(self.df) + * self.scale_operator_pd.sqrt_to_dense()) return self.df * self.scale_operator_pd.to_dense() def _variance(self): @@ -384,7 +398,7 @@ class _WishartOperatorPD(distribution.Distribution): self.dimension * math.log(2.) + self.scale_operator_pd.log_det()) - def log_normalizing_constant(self, name="log_normalizing_constant"): + def log_normalization(self, name="log_normalization"): """Computes the log normalizing constant, log(Z).""" with self._name_scope(name): return (self.df * self.scale_operator_pd.sqrt_log_det() + @@ -429,22 +443,21 @@ class WishartCholesky(_WishartOperatorPD): another O(nbk^3) operation since most uses of Wishart will also use the Cholesky factorization. - #### Mathematical details. 
- - The PDF of this distribution is, - - ``` - f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df) - ``` + #### Mathematical Details - where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd, - `k x k` matrix, and the normalizing constant `B(scale, df)` is given by: + The probability density function (pdf) is, - ``` - B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df) + ```none + pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z + Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df) ``` - where `Gamma_k` is the multivariate Gamma function. + where: + * `df >= k` denotes the degrees of freedom, + * `scale` is a symmetric, positive definite, `k x k` matrix, + * `Z` is the normalizing constant, and, + * `Gamma_k` is the [multivariate Gamma function]( + https://en.wikipedia.org/wiki/Multivariate_gamma_function). #### Examples @@ -499,14 +512,15 @@ class WishartCholesky(_WishartOperatorPD): Cholesky factored matrix. Example `log_prob` input takes a Cholesky and `sample_n` returns a Cholesky when `cholesky_input_output_matrices=True`. - validate_args: `Boolean`, default `False`. Whether to validate input - with asserts. If `validate_args` is `False`, and the inputs are invalid, - correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g., mean, mode) is undefined for any batch - member. If True, batch members with valid parameters leading to - undefined statistics will return `NaN` for this statistic. - name: The name scope to give class member ops. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") @@ -531,22 +545,21 @@ class WishartFull(_WishartOperatorPD): Evaluation of the pdf, determinant, and sampling are all `O(k^3)` operations where `(k, k)` is the event space shape. - #### Mathematical details. - - The PDF of this distribution is, + #### Mathematical Details - ``` - f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df) - ``` + The probability density function (pdf) is, - where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd, - `k x k` matrix, and the normalizing constant `B(scale, df)` is given by: - - ``` - B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df) + ```none + pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z + Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df) ``` - where `Gamma_k` is the multivariate Gamma function. + where: + * `df >= k` denotes the degrees of freedom, + * `scale` is a symmetric, positive definite, `k x k` matrix, + * `Z` is the normalizing constant, and, + * `Gamma_k` is the [multivariate Gamma function]( + https://en.wikipedia.org/wiki/Multivariate_gamma_function). #### Examples @@ -600,14 +613,15 @@ class WishartFull(_WishartOperatorPD): Cholesky factored matrix. 
Example `log_prob` input takes a Cholesky and `sample_n` returns a Cholesky when `cholesky_input_output_matrices=True`. - validate_args: `Boolean`, default `False`. Whether to validate input with - asserts. If `validate_args` is `False`, and the inputs are invalid, - correct behavior is not guaranteed. - allow_nan_stats: `Boolean`, default `True`. If `False`, raise an - exception if a statistic (e.g., mean, mode) is undefined for any batch - member. If True, batch members with valid parameters leading to - undefined statistics will return `NaN` for this statistic. - name: The name scope to give class member ops. + validate_args: Python `Boolean`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: `String` name prefixed to Ops created by this class. """ parameters = locals() parameters.pop("self") |