author    Joshua V. Dillon <jvdillon@google.com>  2017-01-31 17:18:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-01-31 17:39:00 -0800
commit    0e9cebfc8c35781127ff91246598b87ca8ce0aa5 (patch)
tree      189d62189e802b48e9950852aa099233f7b514cf
parent    bbadfff5834f16d6705b09bce26c2b0972c5dc70 (diff)
BREAKING CHANGE: Standardize "concentration", "rate", "total_count" distribution arguments.
BUGFIX: Correct undefined mode in dirichlet.mode.
BUGFIX: Correct broadcasting in dirichletmultinomial.mean.
Change: 146187147
-rw-r--r-- tensorflow/contrib/distributions/__init__.py | 8
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/beta_test.py | 49
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py | 3
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py | 44
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py | 52
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py | 21
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py | 59
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py | 50
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py | 34
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py | 4
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py | 47
-rw-r--r-- tensorflow/contrib/distributions/python/ops/beta.py | 395
-rw-r--r-- tensorflow/contrib/distributions/python/ops/chi2.py | 62
-rw-r--r-- tensorflow/contrib/distributions/python/ops/dirichlet.py | 308
-rw-r--r-- tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py | 320
-rw-r--r-- tensorflow/contrib/distributions/python/ops/exponential.py | 102
-rw-r--r-- tensorflow/contrib/distributions/python/ops/gamma.py | 267
-rw-r--r-- tensorflow/contrib/distributions/python/ops/inverse_gamma.py | 257
-rw-r--r-- tensorflow/contrib/distributions/python/ops/poisson.py | 110
-rw-r--r-- tensorflow/contrib/distributions/python/ops/wishart.py | 154
20 files changed, 1339 insertions, 1007 deletions
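
Before the per-file hunks, the following sketch summarizes the standardized constructor arguments this change introduces. It is not part of the commit; the parameter values are illustrative only, and the old names are noted in comments.

# Illustrative sketch (not part of the commit): the renamed constructor
# arguments applied by the hunks below.
from tensorflow.contrib import distributions as ds

beta = ds.Beta(concentration1=1., concentration0=2.)        # was Beta(a=..., b=...)
dirichlet = ds.Dirichlet(concentration=[1., 2., 3.])        # was Dirichlet(alpha=...)
dir_mult = ds.DirichletMultinomial(total_count=5.,          # was DirichletMultinomial(n=...,
                                   concentration=[1., 2.])  #                          alpha=...)
exponential = ds.Exponential(rate=2.)                       # was Exponential(lam=...)
gamma = ds.Gamma(concentration=3., rate=2.)                 # was Gamma(alpha=..., beta=...)
inv_gamma = ds.InverseGamma(concentration=3., rate=2.)      # was InverseGamma(alpha=..., beta=...)
poisson = ds.Poisson(rate=3.)                               # was Poisson(lam=...)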
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 00946a1f52..9ba3a60044 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -30,16 +30,16 @@ initialized with parameters that define the distributions.
@@Bernoulli
@@BernoulliWithSigmoidProbs
@@Beta
-@@BetaWithSoftplusAB
+@@BetaWithSoftplusConcentration
@@Categorical
@@Chi2
@@Chi2WithAbsDf
@@Exponential
-@@ExponentialWithSoftplusLam
+@@ExponentialWithSoftplusRate
@@Gamma
-@@GammaWithSoftplusAlphaBeta
+@@GammaWithSoftplusConcentrationRate
@@InverseGamma
-@@InverseGammaWithSoftplusAlphaBeta
+@@InverseGammaWithSoftplusConcentrationRate
@@Laplace
@@LaplaceWithSoftplusScale
@@Normal
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py
index 17c3320d70..f524986cec 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py
@@ -69,16 +69,16 @@ class BetaTest(test.TestCase):
b = [[2., 4, 3]]
with self.test_session():
dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.a.get_shape())
- self.assertAllClose(a, dist.a.eval())
+ self.assertEqual([1, 3], dist.concentration1.get_shape())
+ self.assertAllClose(a, dist.concentration1.eval())
def testBetaProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
with self.test_session():
dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.b.get_shape())
- self.assertAllClose(b, dist.b.eval())
+ self.assertEqual([1, 3], dist.concentration0.get_shape())
+ self.assertAllClose(b, dist.concentration0.eval())
def testPdfXProper(self):
a = [[1., 2, 3]]
@@ -88,11 +88,11 @@ class BetaTest(test.TestCase):
dist.prob([.1, .3, .6]).eval()
dist.prob([.2, .3, .5]).eval()
# Either condition can trigger.
- with self.assertRaisesOpError("(Condition x > 0.*|Condition x < y.*)"):
- dist.prob([-1., 1, 1]).eval()
- with self.assertRaisesOpError("Condition x.*"):
- dist.prob([0., 1, 1]).eval()
- with self.assertRaisesOpError("Condition x < y.*"):
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([-1., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be positive"):
+ dist.prob([0., 0.1, 0.5]).eval()
+ with self.assertRaisesOpError("sample must be no larger than `1`"):
dist.prob([.1, .2, 1.2]).eval()
def testPdfTwoBatches(self):
@@ -247,8 +247,7 @@ class BetaTest(test.TestCase):
stats.kstest(
# Beta is a univariate distribution.
sample_values,
- stats.beta(
- a=1., b=2.).cdf)[0],
+ stats.beta(a=1., b=2.).cdf)[0],
0.01)
# The standard error of the sample mean is 1 / (sqrt(18 * n))
self.assertAllClose(
@@ -264,11 +263,15 @@ class BetaTest(test.TestCase):
n_val = 100
random_seed.set_random_seed(654321)
- beta1 = beta_lib.Beta(a=a_val, b=b_val, name="beta1")
+ beta1 = beta_lib.Beta(concentration1=a_val,
+ concentration0=b_val,
+ name="beta1")
samples1 = beta1.sample(n_val, seed=123456).eval()
random_seed.set_random_seed(654321)
- beta2 = beta_lib.Beta(a=a_val, b=b_val, name="beta2")
+ beta2 = beta_lib.Beta(concentration1=a_val,
+ concentration0=b_val,
+ name="beta2")
samples2 = beta2.sample(n_val, seed=123456).eval()
self.assertAllClose(samples1, samples2)
@@ -312,12 +315,12 @@ class BetaTest(test.TestCase):
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
- def testBetaWithSoftplusAB(self):
+ def testBetaWithSoftplusConcentration(self):
with self.test_session():
a, b = -4.2, -9.1
- dist = beta_lib.BetaWithSoftplusAB(a, b)
- self.assertAllClose(nn_ops.softplus(a).eval(), dist.a.eval())
- self.assertAllClose(nn_ops.softplus(b).eval(), dist.b.eval())
+ dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+ self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval())
+ self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval())
def testBetaBetaKL(self):
with self.test_session() as sess:
@@ -326,16 +329,18 @@ class BetaTest(test.TestCase):
b1 = 6.0 * np.random.random(size=shape) + 1e-4
a2 = 6.0 * np.random.random(size=shape) + 1e-4
b2 = 6.0 * np.random.random(size=shape) + 1e-4
- # Take inverse softplus of values to test BetaWithSoftplusAB
+ # Take inverse softplus of values to test BetaWithSoftplusConcentration
a1_sp = np.log(np.exp(a1) - 1.0)
b1_sp = np.log(np.exp(b1) - 1.0)
a2_sp = np.log(np.exp(a2) - 1.0)
b2_sp = np.log(np.exp(b2) - 1.0)
- d1 = beta_lib.Beta(a=a1, b=b1)
- d2 = beta_lib.Beta(a=a2, b=b2)
- d1_sp = beta_lib.BetaWithSoftplusAB(a=a1_sp, b=b1_sp)
- d2_sp = beta_lib.BetaWithSoftplusAB(a=a2_sp, b=b2_sp)
+ d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
+ d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
+ d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
+ concentration0=b1_sp)
+ d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
+ concentration0=b2_sp)
kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
(a1 - a2) * special.digamma(a1) +
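
The testBetaBetaKL hunk above parameterizes BetaWithSoftplusConcentration by applying the inverse softplus, log(exp(x) - 1), to the target concentrations. A plain-NumPy sketch (not part of the commit) of that transform and its round trip through softplus:

# Sketch (not part of the commit): the inverse-softplus transform used in
# testBetaBetaKL above, checked against softplus in plain NumPy.
import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))      # log(1 + exp(x)), what the Softplus wrappers apply

def softplus_inverse(y):
  return np.log(np.exp(y) - 1.0)  # the `*_sp` values constructed in the test

y = np.array([0.3, 1.5, 4.0])
np.testing.assert_allclose(softplus(softplus_inverse(y)), y, rtol=1e-6)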
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
index 4157f0af25..75d48791ec 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
@@ -88,7 +88,8 @@ class Chi2Test(test.TestCase):
df_v = np.array([-1.3, -3.2, 5], dtype=np.float64)
chi2 = chi2_lib.Chi2WithAbsDf(df=df_v)
self.assertAllClose(
- math_ops.floor(math_ops.abs(df_v)).eval(), chi2.df.eval())
+ math_ops.floor(math_ops.abs(df_v)).eval(),
+ chi2.df.eval())
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
index 32230c8a06..235ce20945 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -55,15 +55,15 @@ class DirichletMultinomialTest(test.TestCase):
n = [[5.]]
with self.test_session():
dist = ds.DirichletMultinomial(n, alpha)
- self.assertEqual([1, 1], dist.n.get_shape())
- self.assertAllClose(n, dist.n.eval())
+ self.assertEqual([1, 1], dist.total_count.get_shape())
+ self.assertAllClose(n, dist.total_count.eval())
def testAlphaProperty(self):
alpha = [[1., 2, 3]]
with self.test_session():
dist = ds.DirichletMultinomial(1, alpha)
- self.assertEqual([1, 3], dist.alpha.get_shape())
- self.assertAllClose(alpha, dist.alpha.eval())
+ self.assertEqual([1, 3], dist.concentration.get_shape())
+ self.assertAllClose(alpha, dist.concentration.eval())
def testPmfNandCountsAgree(self):
alpha = [[1., 2, 3]]
@@ -72,9 +72,10 @@ class DirichletMultinomialTest(test.TestCase):
dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
- with self.assertRaisesOpError("Condition x >= 0.*"):
+ with self.assertRaisesOpError("counts must be non-negative"):
dist.prob([-1., 4, 2]).eval()
- with self.assertRaisesOpError("counts do not sum to n"):
+ with self.assertRaisesOpError(
+ "counts last-dimension must sum to `self.total_count`"):
dist.prob([3., 3, 0]).eval()
def testPmfNonIntegerCounts(self):
@@ -86,7 +87,8 @@ class DirichletMultinomialTest(test.TestCase):
dist.prob([3., 0, 2]).eval()
dist.prob([3.0, 0, 2.0]).eval()
# Both equality and integer checking fail.
- with self.assertRaisesOpError("Condition x == y.*"):
+ with self.assertRaisesOpError(
+ "counts cannot contain fractional components"):
dist.prob([1.0, 2.5, 1.5]).eval()
dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
dist.prob([1., 2., 3.]).eval()
@@ -138,7 +140,7 @@ class DirichletMultinomialTest(test.TestCase):
dist = ds.DirichletMultinomial([1.], alpha)
pmf = dist.prob(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
- self.assertEqual((2), pmf.get_shape())
+ self.assertAllEqual([2], pmf.get_shape())
def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
# The probabilities of one vote falling into class k is the mean for class
@@ -148,7 +150,7 @@ class DirichletMultinomialTest(test.TestCase):
counts = [[1., 0], [0., 1]]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
self.assertAllClose([1 / 3., 2 / 3.], pmf.eval())
- self.assertEqual((2), pmf.get_shape())
+ self.assertAllEqual([2], pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
# The probabilities of one vote falling into class k is the mean for class
@@ -158,7 +160,7 @@ class DirichletMultinomialTest(test.TestCase):
counts = [[1., 0]]
pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
- self.assertEqual((2), pmf.get_shape())
+ self.assertAllEqual([2], pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
# The probabilities of one vote falling into class k is the mean for class
@@ -168,7 +170,7 @@ class DirichletMultinomialTest(test.TestCase):
counts = [1., 0]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
self.assertAllClose([1 / 3., 2 / 5.], pmf.eval())
- self.assertEqual((2), pmf.get_shape())
+ self.assertAllEqual([2], pmf.get_shape())
def testPmfForOneVoteIsTheMeanWithOneRecordInput(self):
# The probabilities of one vote falling into class k is the mean for class
@@ -176,15 +178,15 @@ class DirichletMultinomialTest(test.TestCase):
alpha = [1., 2, 3]
with self.test_session():
for class_num in range(3):
- counts = np.zeros((3), dtype=np.float32)
+ counts = np.zeros([3], dtype=np.float32)
counts[class_num] = 1
dist = ds.DirichletMultinomial(1., alpha)
mean = dist.mean().eval()
pmf = dist.prob(counts).eval()
self.assertAllClose(mean[class_num], pmf)
- self.assertTupleEqual((3,), mean.shape)
- self.assertTupleEqual((), pmf.shape)
+ self.assertAllEqual([3], mean.shape)
+ self.assertAllEqual([], pmf.shape)
def testMeanDoubleTwoVotes(self):
# The probabilities of two votes falling into class k for
@@ -193,9 +195,9 @@ class DirichletMultinomialTest(test.TestCase):
alpha = [1., 2, 3]
with self.test_session():
for class_num in range(3):
- counts_one = np.zeros((3), dtype=np.float32)
+ counts_one = np.zeros([3], dtype=np.float32)
counts_one[class_num] = 1.
- counts_two = np.zeros((3), dtype=np.float32)
+ counts_two = np.zeros([3], dtype=np.float32)
counts_two[class_num] = 2
dist1 = ds.DirichletMultinomial(1., alpha)
@@ -205,7 +207,7 @@ class DirichletMultinomialTest(test.TestCase):
mean2 = dist2.mean().eval()
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
- self.assertTupleEqual((3,), mean1.shape)
+ self.assertAllEqual([3], mean1.shape)
def testCovarianceFromSampling(self):
# We will test mean, cov, var, stddev on a DirichletMultinomial constructed
@@ -279,7 +281,7 @@ class DirichletMultinomialTest(test.TestCase):
covariance = dist.covariance()
expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
- self.assertEqual((2, 2), covariance.get_shape())
+ self.assertEqual([2, 2], covariance.get_shape())
self.assertAllClose(expected_covariance, covariance.eval())
def testCovarianceNAlphaBroadcast(self):
@@ -415,7 +417,8 @@ class DirichletMultinomialTest(test.TestCase):
def testSampleUnbiasedNonScalarBatch(self):
with self.test_session() as sess:
dist = ds.DirichletMultinomial(
- n=5., alpha=2. * self._rng.rand(4, 3, 2).astype(np.float32))
+ total_count=5.,
+ concentration=2. * self._rng.rand(4, 3, 2).astype(np.float32))
n = int(3e3)
x = dist.sample(n, seed=0)
sample_mean = math_ops.reduce_mean(x, 0)
@@ -443,7 +446,8 @@ class DirichletMultinomialTest(test.TestCase):
def testSampleUnbiasedScalarBatch(self):
with self.test_session() as sess:
dist = ds.DirichletMultinomial(
- n=5., alpha=2. * self._rng.rand(4).astype(np.float32))
+ total_count=5.,
+ concentration=2. * self._rng.rand(4).astype(np.float32))
n = int(5e3)
x = dist.sample(n, seed=0)
sample_mean = math_ops.reduce_mean(x, 0)
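
The sampling tests above feed the renamed total_count and concentration arguments with different batch shapes. A plain-NumPy sketch (not part of the commit; shapes and values illustrative) of the broadcast mean, mean = total_count * concentration / total_concentration, that the fixed DirichletMultinomial.mean is expected to produce:

# Sketch (not part of the commit): DirichletMultinomial mean with broadcasting
# between total_count and concentration, in plain NumPy.
import numpy as np

total_count = np.array([[3.], [5.]])    # batch shape [2, 1]
concentration = np.array([1., 2., 3.])  # event size 3
total_concentration = concentration.sum(-1, keepdims=True)

mean = total_count * concentration / total_concentration
print(mean.shape)  # (2, 3): two batch members broadcast against one concentration vector
print(mean[0])     # [0.5, 1.0, 1.5] for total_count = 3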
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py
index 919905d245..cd634da09d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py
@@ -46,12 +46,12 @@ class DirichletTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
- def testAlphaProperty(self):
+ def testConcentrationProperty(self):
alpha = [[1., 2, 3]]
with self.test_session():
dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual([1, 3], dist.alpha.get_shape())
- self.assertAllClose(alpha, dist.alpha.eval())
+ self.assertEqual([1, 3], dist.concentration.get_shape())
+ self.assertAllClose(alpha, dist.concentration.eval())
def testPdfXProper(self):
alpha = [[1., 2, 3]]
@@ -60,11 +60,12 @@ class DirichletTest(test.TestCase):
dist.prob([.1, .3, .6]).eval()
dist.prob([.2, .3, .5]).eval()
# Either condition can trigger.
- with self.assertRaisesOpError("Condition x > 0.*|Condition x < y.*"):
- dist.prob([-1., 1, 1]).eval()
- with self.assertRaisesOpError("Condition x > 0.*"):
+ with self.assertRaisesOpError("samples must be positive"):
+ dist.prob([-1., 1.5, 0.5]).eval()
+ with self.assertRaisesOpError("samples must be positive"):
dist.prob([0., .1, .9]).eval()
- with self.assertRaisesOpError("Condition x ~= y.*"):
+ with self.assertRaisesOpError(
+ "sample last-dimension must sum to `1`"):
dist.prob([.1, .2, .8]).eval()
def testPdfZeroBatches(self):
@@ -128,12 +129,12 @@ class DirichletTest(test.TestCase):
self.assertAllClose([1., 3. / 2], pdf.eval())
self.assertEqual((2), pdf.get_shape())
- def testDirichletMean(self):
+ def testMean(self):
with self.test_session():
alpha = [1., 2, 3]
expected_mean = stats.dirichlet.mean(alpha)
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
- self.assertEqual(dirichlet.mean().get_shape(), (3,))
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mean().get_shape(), [3])
self.assertAllClose(dirichlet.mean().eval(), expected_mean)
def testCovarianceFromSampling(self):
@@ -172,51 +173,52 @@ class DirichletTest(test.TestCase):
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
- def testDirichletCovariance(self):
+ def testVariance(self):
with self.test_session():
alpha = [1., 2, 3]
denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
expected_covariance = np.diag(stats.dirichlet.var(alpha))
expected_covariance += [[0., -2, -3], [-2, 0, -6],
[-3, -6, 0]] / denominator
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
self.assertAllClose(dirichlet.covariance().eval(), expected_covariance)
- def testDirichletMode(self):
+ def testMode(self):
with self.test_session():
alpha = np.array([1.1, 2, 3])
expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
- self.assertEqual(dirichlet.mode().get_shape(), (3,))
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
self.assertAllClose(dirichlet.mode().eval(), expected_mode)
- def testDirichletModeInvalid(self):
+ def testModeInvalid(self):
with self.test_session():
alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha, allow_nan_stats=False)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
+ allow_nan_stats=False)
with self.assertRaisesOpError("Condition x < y.*"):
dirichlet.mode().eval()
- def testDirichletModeEnableAllowNanStats(self):
+ def testModeEnableAllowNanStats(self):
with self.test_session():
alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha, allow_nan_stats=True)
- expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
- expected_mode[0] = np.nan
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
+ allow_nan_stats=True)
+ expected_mode = np.zeros_like(alpha) + np.nan
- self.assertEqual(dirichlet.mode().get_shape(), (3,))
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
self.assertAllClose(dirichlet.mode().eval(), expected_mode)
- def testDirichletEntropy(self):
+ def testEntropy(self):
with self.test_session():
alpha = [1., 2, 3]
expected_entropy = stats.dirichlet.entropy(alpha)
- dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
self.assertEqual(dirichlet.entropy().get_shape(), ())
self.assertAllClose(dirichlet.entropy().eval(), expected_entropy)
- def testDirichletSample(self):
+ def testSample(self):
with self.test_session():
alpha = [1., 2]
dirichlet = dirichlet_lib.Dirichlet(alpha)
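
The mode tests above reflect the commit's fix for the undefined Dirichlet mode: mode_k = (concentration_k - 1) / (total_concentration - K) is only defined when every concentration_k exceeds 1, so testModeEnableAllowNanStats now expects an all-NaN result. A plain-NumPy sketch (not part of the commit) of that behavior:

# Sketch (not part of the commit): the Dirichlet mode behavior checked above.
import numpy as np

def dirichlet_mode(concentration):
  concentration = np.asarray(concentration, dtype=np.float64)
  k = concentration.shape[-1]
  if np.all(concentration > 1.):
    return (concentration - 1.) / (concentration.sum() - k)
  return np.full_like(concentration, np.nan)  # undefined, as with allow_nan_stats=True

print(dirichlet_mode([1.1, 2., 3.]))  # well-defined, matches testMode
print(dirichlet_mode([1., 2., 3.]))   # any concentration <= 1 makes every entry NaN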
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
index 5f9a74405a..6171202413 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
@@ -35,7 +35,7 @@ class ExponentialTest(test.TestCase):
lam = constant_op.constant([2.0] * batch_size)
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(lam=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
log_pdf = exponential.log_prob(x)
@@ -53,7 +53,7 @@ class ExponentialTest(test.TestCase):
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(lam=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
cdf = exponential.cdf(x)
@@ -64,7 +64,7 @@ class ExponentialTest(test.TestCase):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_mean = stats.expon.mean(scale=1 / lam_v)
- exponential = exponential_lib.Exponential(lam=lam_v)
+ exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.mean().get_shape(), (3,))
self.assertAllClose(exponential.mean().eval(), expected_mean)
@@ -72,7 +72,7 @@ class ExponentialTest(test.TestCase):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_variance = stats.expon.var(scale=1 / lam_v)
- exponential = exponential_lib.Exponential(lam=lam_v)
+ exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.variance().get_shape(), (3,))
self.assertAllClose(exponential.variance().eval(), expected_variance)
@@ -80,7 +80,7 @@ class ExponentialTest(test.TestCase):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_entropy = stats.expon.entropy(scale=1 / lam_v)
- exponential = exponential_lib.Exponential(lam=lam_v)
+ exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.entropy().get_shape(), (3,))
self.assertAllClose(exponential.entropy().eval(), expected_entropy)
@@ -89,7 +89,7 @@ class ExponentialTest(test.TestCase):
lam = constant_op.constant([3.0, 4.0])
lam_v = [3.0, 4.0]
n = constant_op.constant(100000)
- exponential = exponential_lib.Exponential(lam=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
samples = exponential.sample(n, seed=137)
sample_values = samples.eval()
@@ -107,7 +107,7 @@ class ExponentialTest(test.TestCase):
lam_v = [3.0, 22.0]
lam = constant_op.constant([lam_v] * batch_size)
- exponential = exponential_lib.Exponential(lam=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
n = 100000
samples = exponential.sample(n, seed=138)
@@ -128,11 +128,12 @@ class ExponentialTest(test.TestCase):
stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
0.01)
- def testExponentialWithSoftplusLam(self):
+ def testExponentialWithSoftplusRate(self):
with self.test_session():
lam = [-2.2, -3.4]
- exponential = exponential_lib.ExponentialWithSoftplusLam(lam=lam)
- self.assertAllClose(nn_ops.softplus(lam).eval(), exponential.lam.eval())
+ exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
+ self.assertAllClose(nn_ops.softplus(lam).eval(),
+ exponential.rate.eval())
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
index c5c071b49d..fd62710237 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
@@ -37,7 +37,7 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha = constant_op.constant([3.0] * 5)
beta = constant_op.constant(11.0)
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
self.assertEqual(gamma.batch_shape_tensor().eval(), (5,))
self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
@@ -52,7 +52,7 @@ class GammaTest(test.TestCase):
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_prob(x)
self.assertEqual(log_pdf.get_shape(), (6,))
@@ -70,7 +70,7 @@ class GammaTest(test.TestCase):
alpha_v = np.array([2.0, 4.0])
beta_v = np.array([3.0, 4.0])
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_prob(x)
log_pdf_values = log_pdf.eval()
@@ -90,7 +90,7 @@ class GammaTest(test.TestCase):
alpha_v = np.array([2.0, 4.0])
beta_v = 3.0
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_prob(x)
log_pdf_values = log_pdf.eval()
@@ -111,7 +111,7 @@ class GammaTest(test.TestCase):
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
cdf = gamma.cdf(x)
@@ -122,7 +122,7 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.mean().get_shape(), (3,))
self.assertAllClose(gamma.mean().eval(), expected_means)
@@ -131,7 +131,7 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_modes = (alpha_v - 1) / beta_v
self.assertEqual(gamma.mode().get_shape(), (3,))
self.assertAllClose(gamma.mode().eval(), expected_modes)
@@ -141,7 +141,9 @@ class GammaTest(test.TestCase):
# Mode will not be defined for the first entry.
alpha_v = np.array([0.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, allow_nan_stats=False)
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
gamma.mode().eval()
@@ -150,7 +152,9 @@ class GammaTest(test.TestCase):
# Mode will not be defined for the first entry.
alpha_v = np.array([0.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, allow_nan_stats=True)
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ allow_nan_stats=True)
expected_modes = (alpha_v - 1) / beta_v
expected_modes[0] = np.nan
self.assertEqual(gamma.mode().get_shape(), (3,))
@@ -160,7 +164,7 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.variance().get_shape(), (3,))
self.assertAllClose(gamma.variance().eval(), expected_variances)
@@ -169,7 +173,7 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
self.assertEqual(gamma.stddev().get_shape(), (3,))
self.assertAllClose(gamma.stddev().eval(), expected_stddev)
@@ -179,7 +183,7 @@ class GammaTest(test.TestCase):
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
self.assertEqual(gamma.entropy().get_shape(), (3,))
self.assertAllClose(gamma.entropy().eval(), expected_entropy)
@@ -190,7 +194,7 @@ class GammaTest(test.TestCase):
alpha = constant_op.constant(alpha_v)
beta = constant_op.constant(beta_v)
n = 100000
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (n,))
@@ -213,7 +217,7 @@ class GammaTest(test.TestCase):
alpha = constant_op.constant(alpha_v)
beta = constant_op.constant(beta_v)
n = 100000
- gamma = gamma_lib.Gamma(alpha=alpha, beta=beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (n,))
@@ -233,7 +237,7 @@ class GammaTest(test.TestCase):
with session.Session():
alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v)
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
n = 10000
samples = gamma.sample(n, seed=137)
sample_values = samples.eval()
@@ -268,7 +272,7 @@ class GammaTest(test.TestCase):
def testGammaPdfOfSampleMultiDims(self):
with session.Session() as sess:
- gamma = gamma_lib.Gamma(alpha=[7., 11.], beta=[[5.], [6.]])
+ gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
num = 50000
samples = gamma.sample(num, seed=137)
pdfs = gamma.prob(samples)
@@ -304,22 +308,29 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, validate_args=True)
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
with self.assertRaisesOpError("alpha"):
gamma.mean().eval()
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
- gamma = gamma_lib.Gamma(alpha=alpha_v, beta=beta_v, validate_args=True)
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
with self.assertRaisesOpError("beta"):
gamma.mean().eval()
- def testGammaWithSoftplusAlphaBeta(self):
+ def testGammaWithSoftplusConcentrationRate(self):
with self.test_session():
alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
beta_v = constant_op.constant([1.0, -3.6], name="beta")
- gamma = gamma_lib.GammaWithSoftplusAlphaBeta(alpha=alpha_v, beta=beta_v)
- self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), gamma.alpha.eval())
- self.assertAllEqual(nn_ops.softplus(beta_v).eval(), gamma.beta.eval())
+ gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
+ concentration=alpha_v, rate=beta_v)
+ self.assertAllEqual(nn_ops.softplus(alpha_v).eval(),
+ gamma.concentration.eval())
+ self.assertAllEqual(nn_ops.softplus(beta_v).eval(),
+ gamma.rate.eval())
def testGammaGammaKL(self):
alpha0 = np.array([3.])
@@ -330,8 +341,8 @@ class GammaTest(test.TestCase):
# Build graph.
with self.test_session() as sess:
- g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
- g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
+ g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
+ g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
x = g0.sample(int(1e4), seed=0)
kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
kl_actual = kullback_leibler.kl(g0, g1)
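
testGammaGammaKL above checks the analytic KL against a Monte Carlo estimate, the sample mean of log g0(x) - log g1(x) for x drawn from g0. A SciPy sketch of the same comparison (not part of the commit; parameters illustrative):

# Sketch (not part of the commit): Monte Carlo vs. closed-form KL between two
# Gammas, mirroring testGammaGammaKL above.
import numpy as np
from scipy import stats
from scipy.special import digamma, gammaln

conc0, rate0 = 3.0, 1.0  # g0: concentration/rate
conc1, rate1 = 2.5, 1.5  # g1

rng = np.random.RandomState(0)
x = stats.gamma.rvs(conc0, scale=1. / rate0, size=200000, random_state=rng)
kl_sample = np.mean(stats.gamma.logpdf(x, conc0, scale=1. / rate0) -
                    stats.gamma.logpdf(x, conc1, scale=1. / rate1))

kl_exact = ((conc0 - conc1) * digamma(conc0)
            - gammaln(conc0) + gammaln(conc1)
            + conc1 * (np.log(rate0) - np.log(rate1))
            + conc0 * (rate1 - rate0) / rate0)

np.testing.assert_allclose(kl_sample, kl_exact, atol=0.05)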
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
index 9bcca23966..6eb96ea9ff 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
@@ -33,7 +33,7 @@ class InverseGammaTest(test.TestCase):
with self.test_session():
alpha = constant_op.constant([3.0] * 5)
beta = constant_op.constant(11.0)
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
self.assertEqual(inv_gamma.batch_shape_tensor().eval(), (5,))
self.assertEqual(inv_gamma.batch_shape,
@@ -50,7 +50,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v)
log_pdf = inv_gamma.log_prob(x)
self.assertEqual(log_pdf.get_shape(), (6,))
@@ -68,7 +68,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([2.0, 4.0])
beta_v = np.array([3.0, 4.0])
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v)
log_pdf = inv_gamma.log_prob(x)
log_pdf_values = log_pdf.eval()
@@ -88,7 +88,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([2.0, 4.0])
beta_v = 3.0
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
expected_log_pdf = stats.invgamma.logpdf(x, alpha_v, scale=beta_v)
log_pdf = inv_gamma.log_prob(x)
log_pdf_values = log_pdf.eval()
@@ -109,7 +109,7 @@ class InverseGammaTest(test.TestCase):
beta = constant_op.constant([beta_v] * batch_size)
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
expected_cdf = stats.invgamma.cdf(x, alpha_v, scale=beta_v)
cdf = inv_gamma.cdf(x)
@@ -120,7 +120,7 @@ class InverseGammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
expected_modes = beta_v / (alpha_v + 1)
self.assertEqual(inv_gamma.mode().get_shape(), (3,))
self.assertAllClose(inv_gamma.mode().eval(), expected_modes)
@@ -129,7 +129,7 @@ class InverseGammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
expected_means = stats.invgamma.mean(alpha_v, scale=beta_v)
self.assertEqual(inv_gamma.mean().get_shape(), (3,))
self.assertAllClose(inv_gamma.mean().eval(), expected_means)
@@ -140,7 +140,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, allow_nan_stats=False)
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
inv_gamma.mean().eval()
@@ -150,7 +150,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([0.5, 1.0, 3.0, 2.5])
beta_v = np.array([1.0, 2.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, allow_nan_stats=True)
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
expected_means = beta_v / (alpha_v - 1)
expected_means[0] = np.nan
expected_means[1] = np.nan
@@ -161,7 +161,7 @@ class InverseGammaTest(test.TestCase):
with self.test_session():
alpha_v = np.array([7.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
expected_variances = stats.invgamma.var(alpha_v, scale=beta_v)
self.assertEqual(inv_gamma.variance().get_shape(), (3,))
self.assertAllClose(inv_gamma.variance().eval(), expected_variances)
@@ -171,7 +171,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([1.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, allow_nan_stats=False)
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
inv_gamma.variance().eval()
@@ -180,7 +180,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([1.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, allow_nan_stats=True)
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
expected_variances = stats.invgamma.var(alpha_v, scale=beta_v)
expected_variances[0] = np.nan
self.assertEqual(inv_gamma.variance().get_shape(), (3,))
@@ -191,7 +191,7 @@ class InverseGammaTest(test.TestCase):
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.invgamma.entropy(alpha_v, scale=beta_v)
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
self.assertEqual(inv_gamma.entropy().get_shape(), (3,))
self.assertAllClose(inv_gamma.entropy().eval(), expected_entropy)
@@ -202,7 +202,7 @@ class InverseGammaTest(test.TestCase):
alpha = constant_op.constant(alpha_v)
beta = constant_op.constant(beta_v)
n = 100000
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha, beta=beta)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
samples = inv_gamma.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(samples.get_shape(), (n,))
@@ -222,7 +222,7 @@ class InverseGammaTest(test.TestCase):
with session.Session():
alpha_v = np.array([np.arange(3, 103, dtype=np.float32)]) # 1 x 100
beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- inv_gamma = inverse_gamma.InverseGamma(alpha=alpha_v, beta=beta_v)
+ inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
n = 10000
samples = inv_gamma.sample(n, seed=137)
sample_values = samples.eval()
@@ -257,7 +257,9 @@ class InverseGammaTest(test.TestCase):
def testInverseGammaPdfOfSampleMultiDims(self):
with session.Session() as sess:
- inv_gamma = inverse_gamma.InverseGamma(alpha=[7., 11.], beta=[[5.], [6.]])
+ inv_gamma = inverse_gamma.InverseGamma(
+ concentration=[7., 11.],
+ rate=[[5.], [6.]])
num = 50000
samples = inv_gamma.sample(num, seed=137)
pdfs = inv_gamma.prob(samples)
@@ -294,24 +296,26 @@ class InverseGammaTest(test.TestCase):
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, validate_args=True)
+ concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("alpha"):
inv_gamma.mean().eval()
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
- alpha=alpha_v, beta=beta_v, validate_args=True)
+ concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("beta"):
inv_gamma.mean().eval()
- def testInverseGammaWithSoftplusAlphaBeta(self):
+ def testInverseGammaWithSoftplusConcentrationRate(self):
with self.test_session():
alpha = constant_op.constant([-0.1, -2.9], name="alpha")
beta = constant_op.constant([1.0, -4.8], name="beta")
- inv_gamma = inverse_gamma.InverseGammaWithSoftplusAlphaBeta(
- alpha=alpha, beta=beta, validate_args=True)
- self.assertAllClose(nn_ops.softplus(alpha).eval(), inv_gamma.alpha.eval())
- self.assertAllClose(nn_ops.softplus(beta).eval(), inv_gamma.beta.eval())
+ inv_gamma = inverse_gamma.InverseGammaWithSoftplusConcentrationRate(
+ concentration=alpha, rate=beta, validate_args=True)
+ self.assertAllClose(nn_ops.softplus(alpha).eval(),
+ inv_gamma.concentration.eval())
+ self.assertAllClose(nn_ops.softplus(beta).eval(),
+ inv_gamma.rate.eval())
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
index c88af121a6..0adaf7d816 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
@@ -30,7 +30,7 @@ class PoissonTest(test.TestCase):
def testPoissonShape(self):
with self.test_session():
lam = constant_op.constant([3.0] * 5)
- poisson = poisson_lib.Poisson(lam=lam)
+ poisson = poisson_lib.Poisson(rate=lam)
self.assertEqual(poisson.batch_shape_tensor().eval(), (5,))
self.assertEqual(poisson.batch_shape, tensor_shape.TensorShape([5]))
@@ -38,16 +38,12 @@ class PoissonTest(test.TestCase):
self.assertEqual(poisson.event_shape, tensor_shape.TensorShape([]))
def testInvalidLam(self):
- invalid_lams = [
- -.01,
- 0,
- -2.,
- ]
+ invalid_lams = [-.01, 0, -2.]
for lam in invalid_lams:
with self.test_session():
with self.assertRaisesOpError("Condition x > 0"):
- poisson = poisson_lib.Poisson(lam=lam, validate_args=True)
- poisson.lam.eval()
+ poisson = poisson_lib.Poisson(rate=lam, validate_args=True)
+ poisson.rate.eval()
def testPoissonLogPmf(self):
with self.test_session():
@@ -55,7 +51,7 @@ class PoissonTest(test.TestCase):
lam = constant_op.constant([3.0] * batch_size)
lam_v = 3.0
x = [2., 3., 4., 5., 6., 7.]
- poisson = poisson_lib.Poisson(lam=lam)
+ poisson = poisson_lib.Poisson(rate=lam)
log_pmf = poisson.log_prob(x)
self.assertEqual(log_pmf.get_shape(), (6,))
self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v))
@@ -69,7 +65,7 @@ class PoissonTest(test.TestCase):
batch_size = 6
lam = constant_op.constant([3.0] * batch_size)
x = [2.5, 3.2, 4.3, 5.1, 6., 7.]
- poisson = poisson_lib.Poisson(lam=lam, validate_args=True)
+ poisson = poisson_lib.Poisson(rate=lam, validate_args=True)
# Non-integer
with self.assertRaisesOpError("x has non-integer components"):
@@ -80,7 +76,7 @@ class PoissonTest(test.TestCase):
log_pmf = poisson.log_prob([-1.])
log_pmf.eval()
- poisson = poisson_lib.Poisson(lam=lam, validate_args=False)
+ poisson = poisson_lib.Poisson(rate=lam, validate_args=False)
log_pmf = poisson.log_prob(x)
self.assertEqual(log_pmf.get_shape(), (6,))
pmf = poisson.prob(x)
@@ -93,7 +89,7 @@ class PoissonTest(test.TestCase):
lam_v = [2.0, 4.0, 5.0]
x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T
- poisson = poisson_lib.Poisson(lam=lam)
+ poisson = poisson_lib.Poisson(rate=lam)
log_pmf = poisson.log_prob(x)
self.assertEqual(log_pmf.get_shape(), (6, 3))
self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v))
@@ -109,7 +105,7 @@ class PoissonTest(test.TestCase):
lam_v = 3.0
x = [2.2, 3.1, 4., 5.5, 6., 7.]
- poisson = poisson_lib.Poisson(lam=lam)
+ poisson = poisson_lib.Poisson(rate=lam)
log_cdf = poisson.log_cdf(x)
self.assertEqual(log_cdf.get_shape(), (6,))
self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v))
@@ -125,7 +121,7 @@ class PoissonTest(test.TestCase):
lam_v = [2.0, 4.0, 5.0]
x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T
- poisson = poisson_lib.Poisson(lam=lam)
+ poisson = poisson_lib.Poisson(rate=lam)
log_cdf = poisson.log_cdf(x)
self.assertEqual(log_cdf.get_shape(), (6, 3))
self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v))
@@ -137,7 +133,7 @@ class PoissonTest(test.TestCase):
def testPoissonMean(self):
with self.test_session():
lam_v = [1.0, 3.0, 2.5]
- poisson = poisson_lib.Poisson(lam=lam_v)
+ poisson = poisson_lib.Poisson(rate=lam_v)
self.assertEqual(poisson.mean().get_shape(), (3,))
self.assertAllClose(poisson.mean().eval(), stats.poisson.mean(lam_v))
self.assertAllClose(poisson.mean().eval(), lam_v)
@@ -145,7 +141,7 @@ class PoissonTest(test.TestCase):
def testPoissonVariance(self):
with self.test_session():
lam_v = [1.0, 3.0, 2.5]
- poisson = poisson_lib.Poisson(lam=lam_v)
+ poisson = poisson_lib.Poisson(rate=lam_v)
self.assertEqual(poisson.variance().get_shape(), (3,))
self.assertAllClose(poisson.variance().eval(), stats.poisson.var(lam_v))
self.assertAllClose(poisson.variance().eval(), lam_v)
@@ -153,7 +149,7 @@ class PoissonTest(test.TestCase):
def testPoissonStd(self):
with self.test_session():
lam_v = [1.0, 3.0, 2.5]
- poisson = poisson_lib.Poisson(lam=lam_v)
+ poisson = poisson_lib.Poisson(rate=lam_v)
self.assertEqual(poisson.stddev().get_shape(), (3,))
self.assertAllClose(poisson.stddev().eval(), stats.poisson.std(lam_v))
self.assertAllClose(poisson.stddev().eval(), np.sqrt(lam_v))
@@ -161,14 +157,14 @@ class PoissonTest(test.TestCase):
def testPoissonMode(self):
with self.test_session():
lam_v = [1.0, 3.0, 2.5, 3.2, 1.1, 0.05]
- poisson = poisson_lib.Poisson(lam=lam_v)
+ poisson = poisson_lib.Poisson(rate=lam_v)
self.assertEqual(poisson.mode().get_shape(), (6,))
self.assertAllClose(poisson.mode().eval(), np.floor(lam_v))
def testPoissonMultipleMode(self):
with self.test_session():
lam_v = [1.0, 3.0, 2.0, 4.0, 5.0, 10.0]
- poisson = poisson_lib.Poisson(lam=lam_v)
+ poisson = poisson_lib.Poisson(rate=lam_v)
# For the case where lam is an integer, the modes are: lam and lam - 1.
# In this case, we get back the larger of the two modes.
self.assertEqual((6,), poisson.mode().get_shape())
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
index de7141e660..0e2d143732 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
@@ -201,7 +201,7 @@ class QuantizedDistributionTest(test.TestCase):
# integers. Hence, F(X) (see below) will not be uniform exactly.
with self.test_session():
qdist = distributions.QuantizedDistribution(
- distribution=distributions.Exponential(lam=0.01))
+ distribution=distributions.Exponential(rate=0.01))
# X ~ QuantizedExponential
x = qdist.sample(10000, seed=42)
# Z = F(X), should be Uniform.
@@ -224,7 +224,7 @@ class QuantizedDistributionTest(test.TestCase):
# Make an exponential with mean 5.
with self.test_session():
qdist = distributions.QuantizedDistribution(
- distribution=distributions.Exponential(lam=0.2))
+ distribution=distributions.Exponential(rate=0.2))
# Standard error should be less than 1 / (2 * sqrt(n_samples))
n_samples = 10000
stddev_err_bound = 1 / (2 * np.sqrt(n_samples))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
index 754a896f75..1fa6ca0906 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
@@ -67,10 +67,13 @@ class WishartCholeskyTest(test.TestCase):
with self.test_session():
def entropy_alt(w):
- return (w.log_normalizing_constant() - 0.5 * (w.df - w.dimension - 1.) *
- w.mean_log_det() + 0.5 * w.df * w.dimension).eval()
+ return (
+ w.log_normalization()
+ - 0.5 * (w.df - w.dimension - 1.) * w.mean_log_det()
+ + 0.5 * w.df * w.dimension).eval()
- w = distributions.WishartCholesky(df=4, scale=chol(make_pd(1., 2)))
+ w = distributions.WishartCholesky(df=4,
+ scale=chol(make_pd(1., 2)))
self.assertAllClose(w.entropy().eval(), entropy_alt(w))
w = distributions.WishartCholesky(df=5, scale=[[1.]])
@@ -204,7 +207,9 @@ class WishartCholeskyTest(test.TestCase):
# This test checks that batches don't interfere with correctness.
w = distributions.WishartCholesky(
- df=[2, 3, 4, 5], scale=chol_x, cholesky_input_output_matrices=True)
+ df=[2, 3, 4, 5],
+ scale=chol_x,
+ cholesky_input_output_matrices=True)
self.assertAllClose(log_prob_df_seq, w.log_prob(chol_x).eval())
# Now we test various constructions of Wishart with different sample
@@ -221,10 +226,15 @@ class WishartCholeskyTest(test.TestCase):
-20.951582705289454,
])
- for w in (distributions.WishartCholesky(
- df=4, scale=chol_x[0], cholesky_input_output_matrices=False),
- distributions.WishartFull(
- df=4, scale=x[0], cholesky_input_output_matrices=False)):
+ for w in (
+ distributions.WishartCholesky(
+ df=4,
+ scale=chol_x[0],
+ cholesky_input_output_matrices=False),
+ distributions.WishartFull(
+ df=4,
+ scale=x[0],
+ cholesky_input_output_matrices=False)):
self.assertAllEqual((2, 2), w.event_shape_tensor().eval())
self.assertEqual(2, w.dimension.eval())
self.assertAllClose(log_prob[0], w.log_prob(x[0]).eval())
@@ -238,10 +248,15 @@ class WishartCholeskyTest(test.TestCase):
self.assertAllEqual((2, 2),
w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape())
- for w in (distributions.WishartCholesky(
- df=4, scale=chol_x[0], cholesky_input_output_matrices=True),
- distributions.WishartFull(
- df=4, scale=x[0], cholesky_input_output_matrices=True)):
+ for w in (
+ distributions.WishartCholesky(
+ df=4,
+ scale=chol_x[0],
+ cholesky_input_output_matrices=True),
+ distributions.WishartFull(
+ df=4,
+ scale=x[0],
+ cholesky_input_output_matrices=True)):
self.assertAllEqual((2, 2), w.event_shape_tensor().eval())
self.assertEqual(2, w.dimension.eval())
self.assertAllClose(log_prob[0], w.log_prob(chol_x[0]).eval())
@@ -315,7 +330,9 @@ class WishartCholeskyTest(test.TestCase):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"cannot be less than"):
chol_w = distributions.WishartCholesky(
- df=df_deferred, scale=chol_scale_deferred, validate_args=True)
+ df=df_deferred,
+ scale=chol_scale_deferred,
+ validate_args=True)
sess.run(chol_w.log_prob(np.asarray(
x, dtype=np.float32)),
feed_dict={df_deferred: 2.,
@@ -336,7 +353,9 @@ class WishartCholeskyTest(test.TestCase):
# Ensure no assertions.
chol_w = distributions.WishartCholesky(
- df=df_deferred, scale=chol_scale_deferred, validate_args=False)
+ df=df_deferred,
+ scale=chol_scale_deferred,
+ validate_args=False)
sess.run(chol_w.log_prob(np.asarray(
x, dtype=np.float32)),
feed_dict={df_deferred: 4,
diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py
index baab0503f5..b4e9bc5b2e 100644
--- a/tensorflow/contrib/distributions/python/ops/beta.py
+++ b/tensorflow/contrib/distributions/python/ops/beta.py
@@ -36,158 +36,170 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
-_beta_prob_note = """
- Note that the argument `x` must be a non-negative floating point tensor
- whose shape can be broadcast with `self.a` and `self.b`. For fixed leading
- dimensions, the last dimension represents counts for the corresponding Beta
- distribution in `self.a` and `self.b`. `x` is only legal if `0 < x < 1`.
-"""
+__all__ = [
+ "Beta",
+ "BetaWithSoftplusConcentration",
+]
-class Beta(distribution.Distribution):
- """Beta distribution.
+_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
+`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
- This distribution is parameterized by `a` and `b` which are shape
- parameters.
- #### Mathematical details
+class Beta(distribution.Distribution):
+ """Beta distribution.
- The Beta is a distribution over the interval (0, 1).
- The distribution has hyperparameters `a` and `b` and
- probability mass function (pdf):
+ The Beta distribution is defined over the `(0, 1)` interval using parameters
+ `concentration1` (aka "alpha") and `concentration0` (aka "beta").
- ```pdf(x) = 1 / Beta(a, b) * x^(a - 1) * (1 - x)^(b - 1)```
+ #### Mathematical Details
- where `Beta(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)`
- is the beta function.
+ The probability density function (pdf) is,
+ ```none
+ pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
+ Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
+ ```
- This class provides methods to create indexed batches of Beta
- distributions. One entry of the broadcasted
- shape represents of `a` and `b` represents one single Beta distribution.
- When calling distribution functions (e.g. `dist.prob(x)`), `a`, `b`
- and `x` are broadcast to the same shape (if possible).
- Every entry in a/b/x corresponds to a single Beta distribution.
+ where:
- #### Examples
+ * `concentration1 = alpha`,
+ * `concentration0 = beta`,
+ * `Z` is the normalization constant, and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
- Creates 3 distributions.
- The distribution functions can be evaluated on x.
+ The concentration parameters represent mean total counts of a `1` or a `0`,
+ i.e.,
- ```python
- a = [1, 2, 3]
- b = [1, 2, 3]
- dist = Beta(a, b)
+ ```none
+ concentration1 = alpha = mean * total_concentration
+ concentration0 = beta = (1. - mean) * total_concentration
```
- ```python
- # x same shape as a.
- x = [.2, .3, .7]
- dist.prob(x) # Shape [3]
-
- # a/b will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x.
- x = [[.1, .4, .5], [.2, .3, .5]]
- dist.prob(x) # Shape [2, 3]
+ where `mean` in `(0, 1)` and `total_concentration` is a positive real number
+ representing a mean `total_count = concentration1 + concentration0`.
- # a/b will be broadcast to shape [5, 7, 3] to match x.
- x = [[...]] # Shape [5, 7, 3]
- dist.prob(x) # Shape [5, 7, 3]
- ```
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
- Creates a 2-batch of 3-class distributions.
+ #### Examples
```python
- a = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3]
- b = 5 # Shape []
- dist = Beta(a, b)
+ # Create a batch of three Beta distributions.
+ alpha = [1, 2, 3]
+ beta = [1, 2, 3]
+ dist = Beta(alpha, beta)
+
+ dist.sample([4, 5]) # Shape [4, 5, 3]
+
+ # `x` has three batch entries, each with two samples.
+ x = [[.1, .4, .5],
+ [.2, .3, .5]]
+ # Calculate the probability of each pair of samples under the corresponding
+ # distribution in `dist`.
+ dist.prob(x) # Shape [2, 3]
+ ```
- # x will be broadcast to [[.2, .3, .9], [.2, .3, .9]] to match a/b.
- x = [.2, .3, .9]
- dist.prob(x) # Shape [2]
+ ```python
+ # Create batch_shape=[2, 3] via parameter broadcast:
+ alpha = [[1.], [2]] # Shape [2, 1]
+ beta = [3., 4, 5] # Shape [3]
+ dist = Beta(alpha, beta)
+
+ # alpha broadcast as: [[1., 1, 1,],
+ # [2, 2, 2]]
+ # beta broadcast as: [[3., 4, 5],
+ # [3, 4, 5]]
+ # batch_shape [2, 3]
+ dist.sample([4, 5]) # Shape [4, 5, 2, 3]
+
+ x = [.2, .3, .5]
+ # x will be broadcast as [[.2, .3, .5],
+ # [.2, .3, .5]],
+ # thus matching batch_shape [2, 3].
+ dist.prob(x) # Shape [2, 3]
```
"""
- def __init__(self, a, b, validate_args=False, allow_nan_stats=True,
+ def __init__(self,
+ concentration1=None,
+ concentration0=None,
+ validate_args=False,
+ allow_nan_stats=True,
name="Beta"):
"""Initialize a batch of Beta distributions.
Args:
- a: Positive floating point tensor with shape broadcastable to
- `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
- different Beta distributions. This also defines the
- dtype of the distribution.
- b: Positive floating point tensor with shape broadcastable to
- `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
- different Beta distributions.
- validate_args: `Boolean`, default `False`. Whether to assert valid
- values for parameters `a`, `b`, and `x` in `prob` and `log_prob`.
- If `False` and inputs are invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prefix Ops created by this distribution class.
-
- Examples:
-
- ```python
- # Define 1-batch.
- dist = Beta(1.1, 2.0)
-
- # Define a 2-batch.
- dist = Beta([1.0, 2.0], [4.0, 5.0])
- ```
-
+ concentration1: Positive floating-point `Tensor` indicating mean
+ number of successes; aka "alpha". Implies `self.dtype` and
+ `self.batch_shape`, i.e.,
+ `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
+ concentration0: Positive floating-point `Tensor` indicating mean
+ number of failures; aka "beta". Otherwise has same semantics as
+ `concentration1`.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[a, b]) as ns:
- with ops.control_dependencies([
- check_ops.assert_positive(a),
- check_ops.assert_positive(b),
- ] if validate_args else []):
- self._a = array_ops.identity(a, name="a")
- self._b = array_ops.identity(b, name="b")
- contrib_tensor_util.assert_same_float_dtype((self._a, self._b))
- # Used for mean/mode/variance/entropy/sampling computations
- self._a_b_sum = self._a + self._b
+ with ops.name_scope(name, values=[concentration1,
+ concentration0]) as ns:
+ self._concentration1 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration1, name="concentration1"),
+ validate_args)
+ self._concentration0 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration0, name="concentration0"),
+ validate_args)
+ contrib_tensor_util.assert_same_float_dtype([
+ self._concentration1, self._concentration0])
+ self._total_concentration = self._concentration1 + self._concentration0
super(Beta, self).__init__(
- dtype=self._a_b_sum.dtype,
+ dtype=self._total_concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
is_continuous=True,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
parameters=parameters,
- graph_parents=[self._a, self._b, self._a_b_sum],
+ graph_parents=[self._concentration1,
+ self._concentration0,
+ self._total_concentration],
name=ns)
@staticmethod
def _param_shapes(sample_shape):
- return dict(
- zip(("a", "b"), ([ops.convert_to_tensor(
- sample_shape, dtype=dtypes.int32)] * 2)))
+ return dict(zip(
+ ["concentration1", "concentration0"],
+ [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))
@property
- def a(self):
- """Shape parameter."""
- return self._a
+ def concentration1(self):
+ """Concentration parameter associated with a `1` outcome."""
+ return self._concentration1
@property
- def b(self):
- """Shape parameter."""
- return self._b
+ def concentration0(self):
+ """Concentration parameter associated with a `0` outcome."""
+ return self._concentration0
@property
- def a_b_sum(self):
- """Sum of parameters."""
- return self._a_b_sum
+ def total_concentration(self):
+ """Sum of concentration parameters."""
+ return self._total_concentration
def _batch_shape_tensor(self):
- return array_ops.shape(self.a_b_sum)
+ return array_ops.shape(self.total_concentration)
def _batch_shape(self):
- return self.a_b_sum.get_shape()
+ return self.total_concentration.get_shape()
def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)
@@ -196,103 +208,130 @@ class Beta(distribution.Distribution):
return tensor_shape.scalar()
def _sample_n(self, n, seed=None):
- a = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.a
- b = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.b
+ expanded_concentration1 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration1
+ expanded_concentration0 = array_ops.ones_like(
+ self.total_concentration, dtype=self.dtype) * self.concentration0
gamma1_sample = random_ops.random_gamma(
- [n,], a, dtype=self.dtype, seed=seed)
+ shape=[n],
+ alpha=expanded_concentration1,
+ dtype=self.dtype,
+ seed=seed)
gamma2_sample = random_ops.random_gamma(
- [n,], b, dtype=self.dtype,
+ shape=[n],
+ alpha=expanded_concentration0,
+ dtype=self.dtype,
seed=distribution_util.gen_new_seed(seed, "beta"))
beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
return beta_sample
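`_sample_n` draws a Beta variate as the ratio of two independent Gamma variates, `G1 / (G1 + G2)`, with shapes `concentration1` and `concentration0`. A minimal NumPy sketch of the same recipe (illustrative only; the helper name is hypothetical):

```python
import numpy as np

def sample_beta_via_gammas(concentration1, concentration0, n, seed=None):
    # Beta(c1, c0) as G1 / (G1 + G2) with G1 ~ Gamma(c1, 1), G2 ~ Gamma(c0, 1).
    rng = np.random.default_rng(seed)
    g1 = rng.gamma(shape=concentration1, scale=1.0, size=n)
    g2 = rng.gamma(shape=concentration0, scale=1.0, size=n)
    return g1 / (g1 + g2)

# Sanity check against the analytic mean c1 / (c1 + c0) = 0.25.
print(sample_beta_via_gammas(1.0, 3.0, n=100000, seed=0).mean())  # ~0.25
```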
+ @distribution_util.AppendDocstring(_beta_sample_note)
def _log_prob(self, x):
- x = self._assert_valid_sample(x)
- log_unnormalized_prob = ((self.a - 1.) * math_ops.log(x) +
- (self.b - 1.) * math_ops.log(1. - x))
- log_normalization = (math_ops.lgamma(self.a) +
- math_ops.lgamma(self.b) -
- math_ops.lgamma(self.a_b_sum))
- return log_unnormalized_prob - log_normalization
-
- @distribution_util.AppendDocstring(_beta_prob_note)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
+
+ @distribution_util.AppendDocstring(_beta_sample_note)
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
- @distribution_util.AppendDocstring(_beta_prob_note)
+ @distribution_util.AppendDocstring(_beta_sample_note)
def _log_cdf(self, x):
return math_ops.log(self._cdf(x))
+ @distribution_util.AppendDocstring(_beta_sample_note)
def _cdf(self, x):
- return math_ops.betainc(self.a, self.b, x)
+ return math_ops.betainc(self.concentration1, self.concentration0, x)
+
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ return ((self.concentration1 - 1.) * math_ops.log(x)
+ + (self.concentration0 - 1.) * math_ops.log1p(-x))
+
+ def _log_normalization(self):
+ return (math_ops.lgamma(self.concentration1)
+ + math_ops.lgamma(self.concentration0)
+ - math_ops.lgamma(self.total_concentration))
def _entropy(self):
- return (math_ops.lgamma(self.a) -
- (self.a - 1.) * math_ops.digamma(self.a) +
- math_ops.lgamma(self.b) -
- (self.b - 1.) * math_ops.digamma(self.b) -
- math_ops.lgamma(self.a_b_sum) +
- (self.a_b_sum - 2.) * math_ops.digamma(self.a_b_sum))
+ return (
+ self._log_normalization()
+ - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
+ - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
+ + ((self.total_concentration - 2.) *
+ math_ops.digamma(self.total_concentration)))
def _mean(self):
- return self.a / self.a_b_sum
+ return self._concentration1 / self._total_concentration
def _variance(self):
- return (self.a * self.b) / (self.a_b_sum**2. * (self.a_b_sum + 1.))
+ return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)
@distribution_util.AppendDocstring(
- """Note that the mode for the Beta distribution is only defined
- when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`,
- and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception
- will be raised rather than returning `NaN`.""")
+ """Note: The mode is undefined when `concentration1 <= 1` or
+ `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
+ is used for undefined modes. If `self.allow_nan_stats` is `False` an
+ exception is raised when one or more modes are undefined.""")
def _mode(self):
- mode = (self.a - 1.)/ (self.a_b_sum - 2.)
+ mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
if self.allow_nan_stats:
- nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return array_ops.where(
- math_ops.logical_and(
- math_ops.greater(self.a, 1.),
- math_ops.greater(self.b, 1.)),
- mode,
- array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
- else:
- return control_flow_ops.with_dependencies([
- check_ops.assert_less(
- array_ops.ones((), dtype=self.dtype), self.a,
- message="Mode not defined for components of a <= 1."),
- check_ops.assert_less(
- array_ops.ones((), dtype=self.dtype), self.b,
- message="Mode not defined for components of b <= 1."),
- ], mode)
-
- def _assert_valid_sample(self, x):
- """Check x for proper shape, values, then return tensor version."""
- if not self.validate_args: return x
+ nan = array_ops.fill(
+ self.batch_shape_tensor(),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
+ is_defined = math_ops.logical_and(self.concentration1 > 1.,
+ self.concentration0 > 1.)
+ return array_ops.where(is_defined, mode, nan)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(
+ array_ops.ones([], dtype=self.dtype),
+ self.concentration1,
+ message="Mode undefined for concentration1 <= 1."),
+ check_ops.assert_less(
+ array_ops.ones([], dtype=self.dtype),
+ self.concentration0,
+ message="Mode undefined for concentration0 <= 1.")
+ ], mode)
+
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of a concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ ], concentration)
+
+ def _maybe_assert_valid_sample(self, x):
+ """Checks the validity of a sample."""
+ if not self.validate_args:
+ return x
return control_flow_ops.with_dependencies([
check_ops.assert_positive(
x,
- message="Negative events lie outside Beta distribution support."),
+ message="sample must be positive"),
check_ops.assert_less(
- x, array_ops.ones((), self.dtype),
- message="Event>=1 lies outside Beta distribution support."),
+ x, array_ops.ones([], self.dtype),
+ message="sample must be no larger than `1`."),
], x)
-class BetaWithSoftplusAB(Beta):
- """Beta with softplus transform on `a` and `b`."""
+class BetaWithSoftplusConcentration(Beta):
+ """Beta with softplus transform of `concentration1` and `concentration0`."""
def __init__(self,
- a,
- b,
+ concentration1,
+ concentration0,
validate_args=False,
allow_nan_stats=True,
- name="BetaWithSoftplusAB"):
+ name="BetaWithSoftplusConcentration"):
parameters = locals()
- parameters.pop("self")
- with ops.name_scope(name, values=[a, b]) as ns:
- super(BetaWithSoftplusAB, self).__init__(
- a=nn.softplus(a, name="softplus_a"),
- b=nn.softplus(b, name="softplus_b"),
+ with ops.name_scope(name, values=[concentration1,
+ concentration0]) as ns:
+ super(BetaWithSoftplusConcentration, self).__init__(
+ concentration1=nn.softplus(concentration1,
+ name="softplus_concentration1"),
+ concentration0=nn.softplus(concentration0,
+ name="softplus_concentration0"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
@@ -301,7 +340,7 @@ class BetaWithSoftplusAB(Beta):
@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
- """Calculate the batched KL divergence KL(d1 || d2) with d1 and d2 Beta.
+ """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.
Args:
d1: instance of a Beta distribution object.
@@ -312,14 +351,20 @@ def _kl_beta_beta(d1, d2, name=None):
Returns:
Batchwise KL(d1 || d2)
"""
- inputs = [d1.a, d1.b, d1.a_b_sum, d2.a_b_sum]
- with ops.name_scope(name, "kl_beta_beta", inputs):
- # ln(B(a', b') / B(a, b))
- log_betas = (math_ops.lgamma(d2.a) + math_ops.lgamma(d2.b)
- - math_ops.lgamma(d2.a_b_sum) + math_ops.lgamma(d1.a_b_sum)
- - math_ops.lgamma(d1.a) - math_ops.lgamma(d1.b))
- # (a - a')*psi(a) + (b - b')*psi(b) + (a' - a + b' - b)*psi(a + b)
- digammas = ((d1.a - d2.a) * math_ops.digamma(d1.a)
- + (d1.b - d2.b) * math_ops.digamma(d1.b)
- + (d2.a_b_sum - d1.a_b_sum) * math_ops.digamma(d1.a_b_sum))
- return log_betas + digammas
+ def delta(fn, is_property=True):
+ fn1 = getattr(d1, fn)
+ fn2 = getattr(d2, fn)
+ return (fn2 - fn1) if is_property else (fn2() - fn1())
+ with ops.name_scope(name, "kl_beta_beta", values=[
+ d1.concentration1,
+ d1.concentration0,
+ d1.total_concentration,
+ d2.concentration1,
+ d2.concentration0,
+ d2.total_concentration,
+ ]):
+ return (delta("_log_normalization", is_property=False)
+ - math_ops.digamma(d1.concentration1) * delta("concentration1")
+ - math_ops.digamma(d1.concentration0) * delta("concentration0")
+ + (math_ops.digamma(d1.total_concentration)
+ * delta("total_concentration")))
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index 210cbe95c8..870646e26f 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -25,15 +25,40 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
+__all__ = [
+ "Chi2",
+ "Chi2WithAbsDf",
+]
+
+
class Chi2(gamma.Gamma):
- """The Chi2 distribution with degrees of freedom df.
+ """Chi2 distribution.
+
+ The Chi2 distribution is defined over positive real numbers using a degrees of
+ freedom ("df") parameter.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; df, x > 0) = x**(0.5 df - 1) exp(-0.5 x) / Z
+ Z = 2**(0.5 df) Gamma(0.5 df)
+ ```
+
+ where:
+
+ * `df` denotes the degrees of freedom,
+ * `Z` is the normalization constant, and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
- The PDF of this distribution is:
+ The Chi2 distribution is a special case of the Gamma distribution, i.e.,
- ```pdf(x) = (x^(df/2 - 1)e^(-x/2))/(2^(df/2)Gamma(df/2)), x > 0```
+ ```python
+ Chi2(df) = Gamma(concentration=0.5 * df, rate=0.5)
+ ```
- Note that the Chi2 distribution is a special case of the Gamma distribution,
- with Chi2(df) = Gamma(df/2, 1/2).
"""
def __init__(self,
@@ -46,15 +71,15 @@ class Chi2(gamma.Gamma):
Args:
df: Floating point tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `df > 0`, and that `x > 0` in the methods `prob(x)` and `log_prob(x)`.
- If `validate_args` is `False` and the inputs are invalid, correct
- behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prepend to all ops created by this distribution.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
@@ -65,8 +90,8 @@ class Chi2(gamma.Gamma):
with ops.name_scope(name, values=[df]) as ns:
self._df = ops.convert_to_tensor(df, name="df")
super(Chi2, self).__init__(
- alpha=0.5 * self._df,
- beta=constant_op.constant(0.5, dtype=self._df.dtype),
+ concentration=0.5 * self._df,
+ rate=constant_op.constant(0.5, dtype=self._df.dtype),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
@@ -93,8 +118,9 @@ class Chi2WithAbsDf(Chi2):
parameters.pop("self")
with ops.name_scope(name, values=[df]) as ns:
super(Chi2WithAbsDf, self).__init__(
- df=math_ops.floor(math_ops.abs(df, name="abs_df"),
- name="floor_abs_df"),
+ df=math_ops.floor(
+ math_ops.abs(df, name="abs_df"),
+ name="floor_abs_df"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py
index 2c39c73195..30d8d136ec 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -30,189 +31,202 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
-_dirichlet_prob_note = """
-Note that the input must be a non-negative tensor with dtype `dtype` and whose
-shape can be broadcast with `self.alpha`. For fixed leading dimensions, the
-last dimension represents counts for the corresponding Dirichlet distribution
-in `self.alpha`. `x` is only legal if it sums up to one.
-"""
+__all__ = [
+ "Dirichlet",
+]
+
+
+_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
+dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
+`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
+`self.batch_shape() + self.event_shape()`."""
class Dirichlet(distribution.Distribution):
"""Dirichlet distribution.
- This distribution is parameterized by a vector `alpha` of concentration
- parameters for `k` classes.
+ The Dirichlet distribution is defined over the
+ [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
+ length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
+ Beta distribution when `k = 2`.
- #### Mathematical details
+ #### Mathematical Details
- The Dirichlet is a distribution over the standard n-simplex, where the
- standard n-simplex is defined by:
- ```{ (x_1, ..., x_n) in R^(n+1) | sum_j x_j = 1 and x_j >= 0 for all j }```.
- The distribution has hyperparameters `alpha = (alpha_1,...,alpha_k)`,
- and probability mass function (prob):
+ The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,
- ```prob(x) = 1 / Beta(alpha) * prod_j x_j^(alpha_j - 1)```
+ ```none
+ S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
+ ```
+
+ The probability density function (pdf) is,
- where `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate
- beta function.
+ ```none
+ pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
+ Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
+ ```
+
+ where:
+
+ * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
+ * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
+ * `Z` is the normalization constant aka the [multivariate beta function](
+ https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
+ and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
+
+ The `concentration` represents mean total counts of class occurrence, i.e.,
+
+ ```none
+ concentration = alpha = mean * total_concentration
+ ```
+ where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
+ representing a mean total count.
- This class provides methods to create indexed batches of Dirichlet
- distributions. If the provided `alpha` is rank 2 or higher, for
- every fixed set of leading dimensions, the last dimension represents one
- single Dirichlet distribution. When calling distribution
- functions (e.g. `dist.prob(x)`), `alpha` and `x` are broadcast to the
- same shape (if possible). In all cases, the last dimension of alpha/x
- represents single Dirichlet distributions.
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
#### Examples
```python
- alpha = [1, 2, 3]
+ # Create a single trivariate Dirichlet, with the 3rd class being three times
+ # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
+ alpha = [1., 2, 3]
dist = Dirichlet(alpha)
- ```
- Creates a 3-class distribution, with the 3rd class is most likely to be drawn.
- The distribution functions can be evaluated on x.
+ dist.sample([4, 5]) # shape: [4, 5, 3]
- ```python
- # x same shape as alpha.
- x = [.2, .3, .5]
- dist.prob(x) # Shape []
+ # x has one sample, one batch, three classes:
+ x = [.2, .3, .5] # shape: [3]
+ dist.prob(x) # shape: []
- # alpha will be broadcast to [[1, 2, 3], [1, 2, 3]] to match x.
- x = [[.1, .4, .5], [.2, .3, .5]]
- dist.prob(x) # Shape [2]
+ # x has two samples from one batch:
+ x = [[.1, .4, .5],
+ [.2, .3, .5]]
+ dist.prob(x) # shape: [2]
# alpha will be broadcast to shape [5, 7, 3] to match x.
- x = [[...]] # Shape [5, 7, 3]
- dist.prob(x) # Shape [5, 7]
+ x = [[...]] # shape: [5, 7, 3]
+ dist.prob(x) # shape: [5, 7]
```
- Creates a 2-batch of 3-class distributions.
-
```python
- alpha = [[1, 2, 3], [4, 5, 6]] # Shape [2, 3]
+ # Create batch_shape=[2], event_shape=[3]:
+ alpha = [[1., 2, 3],
+ [4, 5, 6]] # shape: [2, 3]
dist = Dirichlet(alpha)
- # x will be broadcast to [[2, 1, 0], [2, 1, 0]] to match alpha.
+ dist.sample([4, 5]) # shape: [4, 5, 2, 3]
+
x = [.2, .3, .5]
- dist.prob(x) # Shape [2]
+ # x will be broadcast as [[.2, .3, .5],
+ # [.2, .3, .5]],
+ # thus matching batch_shape [2] and event_shape [3].
+ dist.prob(x) # shape: [2]
```
"""
def __init__(self,
- alpha,
+ concentration,
validate_args=False,
allow_nan_stats=True,
name="Dirichlet"):
"""Initialize a batch of Dirichlet distributions.
Args:
- alpha: Positive floating point tensor with shape broadcastable to
- `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
- different `k` class Dirichlet distributions.
- validate_args: `Boolean`, default `False`. Whether to assert valid values
- for parameters `alpha` and `x` in `prob` and `log_prob`. If `False`,
- correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prefix Ops created by this distribution class.
-
- Examples:
-
- ```python
- # Define 1-batch of 2-class Dirichlet distributions,
- # also known as a Beta distribution.
- dist = Dirichlet([1.1, 2.0])
-
- # Define a 2-batch of 3-class distributions.
- dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- ```
-
+ concentration: Positive floating-point `Tensor` indicating mean number
+ of class occurrences; aka "alpha". Implies `self.dtype`, and
+ `self.batch_shape`, `self.event_shape`, i.e., if
+ `concentration.shape = [N1, N2, ..., Nm, k]` then
+ `batch_shape = [N1, N2, ..., Nm]` and
+ `event_shape = [k]`.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[alpha]) as ns:
- alpha = ops.convert_to_tensor(alpha, name="alpha")
- with ops.control_dependencies([
- check_ops.assert_positive(alpha),
- check_ops.assert_rank_at_least(alpha, 1)
- ] if validate_args else []):
- self._alpha = array_ops.identity(alpha, name="alpha")
- self._alpha_sum = math_ops.reduce_sum(alpha,
- reduction_indices=[-1],
- keep_dims=False)
+ with ops.name_scope(name, values=[concentration]) as ns:
+ self._concentration = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration, name="concentration"),
+ validate_args)
+ self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
super(Dirichlet, self).__init__(
- dtype=self._alpha.dtype,
+ dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
is_continuous=True,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
parameters=parameters,
- graph_parents=[self._alpha, self._alpha_sum],
+ graph_parents=[self._concentration,
+ self._total_concentration],
name=ns)
@property
- def alpha(self):
- """Shape parameter."""
- return self._alpha
+ def concentration(self):
+ """Concentration parameter; expected counts for that coordinate."""
+ return self._concentration
@property
- def alpha_sum(self):
- """Sum of shape parameter."""
- return self._alpha_sum
+ def total_concentration(self):
+ """Sum of last dim of concentration parameter."""
+ return self._total_concentration
def _batch_shape_tensor(self):
- return array_ops.shape(self.alpha_sum)
+ return array_ops.shape(self.total_concentration)
def _batch_shape(self):
- return self.alpha_sum.get_shape()
+ return self.total_concentration.get_shape()
def _event_shape_tensor(self):
- return array_ops.gather(array_ops.shape(self.alpha),
- [array_ops.rank(self.alpha) - 1])
+ return array_ops.shape(self.concentration)[-1:]
def _event_shape(self):
- return self.alpha.get_shape().with_rank_at_least(1)[-1:]
+ return self.concentration.get_shape().with_rank_at_least(1)[-1:]
def _sample_n(self, n, seed=None):
gamma_sample = random_ops.random_gamma(
- [n,], self.alpha, dtype=self.dtype, seed=seed)
- return gamma_sample / math_ops.reduce_sum(
- gamma_sample, reduction_indices=[-1], keep_dims=True)
+ shape=[n],
+ alpha=self.concentration,
+ dtype=self.dtype,
+ seed=seed)
+ return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keep_dims=True)
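`_sample_n` produces Dirichlet draws by normalizing independent Gamma variates along the last axis. A minimal NumPy sketch of the same recipe (illustrative only; the helper name is hypothetical):

```python
import numpy as np

def sample_dirichlet_via_gammas(concentration, n, seed=None):
    # Draw Gamma(alpha_j, 1) per coordinate, then normalize onto the simplex.
    rng = np.random.default_rng(seed)
    concentration = np.asarray(concentration, dtype=np.float64)
    gammas = rng.gamma(shape=concentration, size=(n,) + concentration.shape)
    return gammas / gammas.sum(axis=-1, keepdims=True)

samples = sample_dirichlet_via_gammas([1., 2., 3.], n=100000, seed=0)
print(samples.sum(-1)[:3])  # each row sums to 1
print(samples.mean(0))      # ~ alpha / sum(alpha) = [1/6, 1/3, 1/2]
```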
- @distribution_util.AppendDocstring(_dirichlet_prob_note)
+ @distribution_util.AppendDocstring(_dirichlet_sample_note)
def _log_prob(self, x):
- x = ops.convert_to_tensor(x, name="x")
- x = self._assert_valid_sample(x)
- unnorm_prob = (self.alpha - 1.) * math_ops.log(x)
- log_prob = math_ops.reduce_sum(
- unnorm_prob, reduction_indices=[-1],
- keep_dims=False) - special_math_ops.lbeta(self.alpha)
- return log_prob
-
- @distribution_util.AppendDocstring(_dirichlet_prob_note)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
+
+ @distribution_util.AppendDocstring(_dirichlet_sample_note)
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)
+
+ def _log_normalization(self):
+ return special_math_ops.lbeta(self.concentration)
+
def _entropy(self):
- entropy = special_math_ops.lbeta(self.alpha)
- entropy += math_ops.digamma(self.alpha_sum) * (
- self.alpha_sum
- - math_ops.cast(self.event_shape_tensor()[0], self.dtype))
- entropy += -math_ops.reduce_sum(
- (self.alpha - 1.) * math_ops.digamma(self.alpha),
- reduction_indices=[-1],
- keep_dims=False)
- return entropy
+ k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
+ return (
+ self._log_normalization()
+ + ((self.total_concentration - k)
+ * math_ops.digamma(self.total_concentration))
+ - math_ops.reduce_sum(
+ (self.concentration - 1.) * math_ops.digamma(self.concentration),
+ axis=-1))
def _mean(self):
- return self.alpha / self.alpha_sum[..., None]
+ return self.concentration / self.total_concentration[..., None]
def _covariance(self):
x = self._variance_scale_term() * self._mean()
@@ -227,37 +241,57 @@ class Dirichlet(distribution.Distribution):
def _variance_scale_term(self):
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
- return math_ops.rsqrt(1. + self.alpha_sum[..., None])
+ return math_ops.rsqrt(1. + self.total_concentration[..., None])
@distribution_util.AppendDocstring(
- """Note that the mode for the Dirichlet distribution is only defined
- when `alpha > 1`. This returns the mode when `alpha > 1`,
- and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception
- will be raised rather than returning `NaN`.""")
+ """Note: The mode is undefined when any `concentration <= 1`. If
+ `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
+ `self.allow_nan_stats` is `False` an exception is raised when one or more
+ modes are undefined.""")
def _mode(self):
- mode = ((self.alpha - 1.) /
- (array_ops.expand_dims(self.alpha_sum, dim=-1) -
- math_ops.cast(self.event_shape_tensor()[0], self.dtype)))
+ k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
+ mode = (self.concentration - 1.) / (self.total_concentration[..., None] - k)
if self.allow_nan_stats:
- nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- shape = array_ops.concat([self.batch_shape_tensor(),
- self.event_shape_tensor()], 0)
+ nan = array_ops.fill(
+ array_ops.shape(mode),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
return array_ops.where(
- math_ops.greater(self.alpha, 1.),
- mode,
- array_ops.fill(shape, nan, name="nan"))
- else:
- return control_flow_ops.with_dependencies([
- check_ops.assert_less(
- array_ops.ones((), dtype=self.dtype), self.alpha,
- message="mode not defined for components of alpha <= 1")
- ], mode)
-
- def _assert_valid_sample(self, x):
- if not self.validate_args: return x
+ math_ops.reduce_all(self.concentration > 1., axis=-1),
+ mode, nan)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(
+ array_ops.ones([], self.dtype),
+ self.concentration,
+ message="Mode undefined when any concentration <= 1"),
+ ], mode)
+
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of the concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ check_ops.assert_rank_at_least(
+ concentration, 1,
+ message="Concentration parameter must have >=1 dimensions."),
+ check_ops.assert_less(
+ 1, array_ops.shape(concentration)[-1],
+ message="Concentration parameter must have event_size >= 2."),
+ ], concentration)
+
+ def _maybe_assert_valid_sample(self, x):
+ """Checks the validity of a sample."""
+ if not self.validate_args:
+ return x
return control_flow_ops.with_dependencies([
- check_ops.assert_positive(x),
+ check_ops.assert_positive(
+ x,
+ message="samples must be positive"),
distribution_util.assert_close(
array_ops.ones((), dtype=self.dtype),
- math_ops.reduce_sum(x, reduction_indices=[-1])),
+ math_ops.reduce_sum(x, -1),
+ message="sample last-dimension must sum to `1`"),
], x)
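The log-density above is split into an unnormalized term and a log-normalizer (the multivariate log-beta). A hedged NumPy/SciPy sketch of the same split for a 1-D concentration vector, checked against `scipy.stats.dirichlet` (illustrative only):

```python
import numpy as np
from scipy import special, stats

def dirichlet_log_prob(x, concentration):
    # log p(x) = sum_j (alpha_j - 1) log x_j - log B(alpha), with
    # log B(alpha) = sum_j lgamma(alpha_j) - lgamma(sum_j alpha_j).
    x = np.asarray(x, dtype=np.float64)
    alpha = np.asarray(concentration, dtype=np.float64)
    log_unnormalized = np.sum((alpha - 1.) * np.log(x), axis=-1)
    log_normalization = np.sum(special.gammaln(alpha)) - special.gammaln(alpha.sum())
    return log_unnormalized - log_normalization

x = [0.2, 0.3, 0.5]
alpha = [1., 2., 3.]
print(dirichlet_log_prob(x, alpha), stats.dirichlet(alpha).logpdf(x))  # agree
```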
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
index 84c11d3801..0ba5165759 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""The Dirichlet Multinomial distribution class."""
+"""The DirichletMultinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
@@ -30,52 +30,76 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
-_dirichlet_multinomial_prob_note = """
-For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
-that after sampling `n` draws from this Dirichlet Multinomial
-distribution, the number of draws falling in class `j` is `n_j`. Note that
-different sequences of draws can result in the same counts, thus the
-probability includes a combinatorial coefficient.
+__all__ = [
+ "DirichletMultinomial",
+]
-Note that input, "counts", must be a non-negative tensor with dtype `dtype`
-and whose shape can be broadcast with `self.alpha`. For fixed leading
-dimensions, the last dimension represents counts for the corresponding
-Dirichlet Multinomial distribution in `self.alpha`. `counts` is only legal if
-it sums up to `n` and its components are equal to integer values.
-"""
+
+_dirichlet_multinomial_sample_note = """For each batch of counts,
+`value = [n_0, ..., n_{k-1}]`, `P[value]` is the probability that after sampling
+`self.total_count` draws from this Dirichlet-Multinomial distribution, the
+number of draws falling in class `j` is `n_j`. Since this definition is
+[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
+different sequences have the same counts, so the probability includes a
+combinatorial coefficient.
+
+Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
+fractional components, and such that
+`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
+with `self.concentration` and `self.total_count`."""
class DirichletMultinomial(distribution.Distribution):
- """DirichletMultinomial mixture distribution.
+ """Dirichlet-Multinomial compound distribution.
+
+ The Dirichlet-Multinomial distribution is parameterized by a (batch of)
+ length-`k` `concentration` vectors (`k > 1`) and a `total_count` number of
+ trials, i.e., the number of trials per draw from the DirichletMultinomial. It
+ is defined over a (batch of) length-`k` vector `counts` such that
+ `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
+ identically the Beta-Binomial distribution when `k = 2`.
- This distribution is parameterized by a vector `alpha` of concentration
- parameters for `k` classes and `n`, the counts per each class..
+ #### Mathematical Details
- #### Mathematical details
+ The Dirichlet-Multinomial is a distribution over `k`-class counts, i.e., a
+ length-`k` vector of non-negative integer `counts = n = [n_0, ..., n_{k-1}]`.
+
+ The probability mass function (pmf) is,
+
+ ```none
+ pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
+ Z = Beta(alpha) / N!
+ ```
- The Dirichlet Multinomial is a distribution over k-class count data, meaning
- for each k-tuple of non-negative integer `counts = [c_1,...,c_k]`, we have a
- probability of these draws being made from the distribution. The distribution
- has hyperparameters `alpha = (alpha_1,...,alpha_k)`, and probability mass
- function (pmf):
+ where:
- ```pmf(counts) = N! / (n_1!...n_k!) * Beta(alpha + c) / Beta(alpha)```
+ * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
+ * `total_count = N`, `N` a positive integer,
+ * `N!` is `N` factorial, and,
+ * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
+ [multivariate beta function](
+ https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
+ and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
- where above `N = sum_j n_j`, `N!` is `N` factorial, and
- `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the multivariate beta
- function.
+ Dirichlet-Multinomial is a [compound distribution](
+ https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
+ samples are generated as follows.
- This is a mixture distribution in that `M` samples can be produced by:
- 1. Choose class probabilities `p = (p_1,...,p_k) ~ Dir(alpha)`
- 2. Draw integers `m = (n_1,...,n_k) ~ Multinomial(N, p)`
+ 1. Choose class probabilities:
+ `probs = [p_0,...,p_{k-1}] ~ Dir(concentration)`
+ 2. Draw integers:
+ `counts = [n_0,...,n_{k-1}] ~ Multinomial(total_count, probs)`
- This class provides methods to create indexed batches of Dirichlet
- Multinomial distributions. If the provided `alpha` is rank 2 or higher, for
- every fixed set of leading dimensions, the last dimension represents one
- single Dirichlet Multinomial distribution. When calling distribution
- functions (e.g. `dist.prob(counts)`), `alpha` and `counts` are broadcast to
- the same shape (if possible). In all cases, the last dimension of
- alpha/counts represents single Dirichlet Multinomial distributions.
+ The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
+ distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
+ `concentration`, `total_count` and `counts` are broadcast to the same shape.
+ The last dimension of `counts` corresponds to a single Dirichlet-Multinomial
+ distribution.
+
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
#### Examples
@@ -116,116 +140,102 @@ class DirichletMultinomial(distribution.Distribution):
"""
- # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
+ # TODO(b/27419586) Change docstring for dtype of concentration once int
+ # allowed.
def __init__(self,
- n,
- alpha,
+ total_count,
+ concentration,
validate_args=False,
allow_nan_stats=True,
name="DirichletMultinomial"):
"""Initialize a batch of DirichletMultinomial distributions.
Args:
- n: Non-negative floating point tensor, whose dtype is the same as
- `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
- Defines this as a batch of `N1 x ... x Nm` different Dirichlet
- multinomial distributions. Its components should be equal to integer
- values.
- alpha: Positive floating point tensor, whose dtype is the same as
- `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
- this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
+ total_count: Non-negative floating point tensor, whose dtype is the same
+ as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
+ `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
+ Dirichlet multinomial distributions. Its components should be equal to
+ integer values.
+ concentration: Positive floating point tensor, whose dtype is the
+ same as `total_count` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`.
+ Defines this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
multinomial distributions.
- validate_args: `Boolean`, default `False`. Whether to assert valid
- values for parameters `alpha` and `n`, and `x` in `prob` and
- `log_prob`. If `False`, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prefix Ops created by this distribution class.
-
- Examples:
-
- ```python
- # Define 1-batch of 2-class Dirichlet multinomial distribution,
- # also known as a beta-binomial.
- dist = DirichletMultinomial(2.0, [1.1, 2.0])
-
- # Define a 2-batch of 3-class distributions.
- dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- ```
-
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[n, alpha]) as ns:
+ with ops.name_scope(name, values=[total_count, concentration]) as ns:
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
- # we use the last dimension for the distribution, wherease
+ # we use the last dimension for the distribution, whereas
# the batch dimensions are the leading dimensions, which forces the
# distribution dimension to be defined explicitly (i.e. it cannot be
# created automatically by prepending). This forces enough
- # explicitivity.
- # * All calls involving `counts` eventually require a broadcast between
- # `counts` and alpha.
- self._alpha = self._assert_valid_alpha(alpha, validate_args)
- self._n = self._assert_valid_n(n, validate_args)
- self._alpha_sum = math_ops.reduce_sum(
- self._alpha, reduction_indices=[-1], keep_dims=False)
+ # explicitness.
+ # * All calls involving `counts` eventually require a broadcast between
+ # `counts` and concentration.
+ self._total_count = self._maybe_assert_valid_total_count(
+ ops.convert_to_tensor(total_count, name="total_count"),
+ validate_args)
+ self._concentration = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration,
+ name="concentration"),
+ validate_args)
+ self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
super(DirichletMultinomial, self).__init__(
- dtype=self._alpha.dtype,
- is_continuous=False,
- reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
+ is_continuous=False,
+ reparameterization_type=distribution.NOT_REPARAMETERIZED,
parameters=parameters,
- graph_parents=[self._alpha, self._n, self._alpha_sum],
+ graph_parents=[self._total_count,
+ self._concentration],
name=ns)
@property
- def n(self):
- """Parameter defining this distribution."""
- return self._n
+ def total_count(self):
+ """Number of trials used to construct a sample."""
+ return self._total_count
@property
- def alpha(self):
- """Parameter defining this distribution."""
- return self._alpha
+ def concentration(self):
+ """Concentration parameter; expected prior counts for that coordinate."""
+ return self._concentration
@property
- def alpha_sum(self):
- """Summation of alpha parameter."""
- return self._alpha_sum
+ def total_concentration(self):
+ """Sum of last dim of concentration parameter."""
+ return self._total_concentration
def _batch_shape_tensor(self):
- return array_ops.shape(self.alpha_sum)
+ return array_ops.shape(self.total_concentration)
def _batch_shape(self):
- return self.alpha_sum.get_shape()
+ return self.total_concentration.get_shape()
def _event_shape_tensor(self):
- return array_ops.reverse_v2(array_ops.shape(self.alpha), [0])[0]
+ return array_ops.shape(self.concentration)[-1:]
def _event_shape(self):
- # Event shape depends only on alpha, not "n".
- return self.alpha.get_shape().with_rank_at_least(1)[-1:]
+ # Event shape depends only on concentration, not "total_count".
+ return self.concentration.get_shape().with_rank_at_least(1)[-1:]
def _sample_n(self, n, seed=None):
- n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
- if self.n.get_shape().ndims is not None:
- if self.n.get_shape().ndims != 0:
- raise NotImplementedError(
- "Sample only supported for scalar number of draws.")
- elif self.validate_args:
- is_scalar = check_ops.assert_rank(
- n_draws, 0,
- message="Sample only supported for scalar number of draws.")
- n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
+ n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
k = self.event_shape_tensor()[0]
unnormalized_logits = array_ops.reshape(
math_ops.log(random_ops.random_gamma(
shape=[n],
- alpha=self.alpha,
+ alpha=self.concentration,
dtype=self.dtype,
seed=seed)),
shape=[-1, k])
@@ -233,40 +243,41 @@ class DirichletMultinomial(distribution.Distribution):
logits=unnormalized_logits,
num_samples=n_draws,
seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
- x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
- reduction_indices=-2)
+ x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
return array_ops.reshape(x, final_shape)
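The compound sampling recipe above (class probabilities from a Dirichlet, then counts from a Multinomial) can be sketched directly in NumPy (illustrative only; the helper name is hypothetical and `total_count` is assumed scalar):

```python
import numpy as np

def sample_dirichlet_multinomial(total_count, concentration, n, seed=None):
    # probs ~ Dirichlet(concentration), then counts ~ Multinomial(total_count, probs).
    rng = np.random.default_rng(seed)
    probs = rng.dirichlet(concentration, size=n)  # shape [n, k]
    return np.stack([rng.multinomial(int(total_count), p) for p in probs])

counts = sample_dirichlet_multinomial(10, [1., 2., 3.], n=20000, seed=0)
print(counts.sum(-1)[:3])  # every draw sums to total_count = 10
print(counts.mean(0))      # ~ total_count * alpha / sum(alpha) = [10/6, 20/6, 30/6]
```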
- @distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note)
+ @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
def _log_prob(self, counts):
- counts = self._assert_valid_counts(counts)
- ordered_prob = (special_math_ops.lbeta(self.alpha + counts) -
- special_math_ops.lbeta(self.alpha))
- log_prob = ordered_prob + distribution_util.log_combinations(
- self.n, counts)
- return log_prob
-
- @distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note)
+ counts = self._maybe_assert_valid_sample(counts)
+ ordered_prob = (
+ special_math_ops.lbeta(self.concentration + counts)
+ - special_math_ops.lbeta(self.concentration))
+ return ordered_prob + distribution_util.log_combinations(
+ self.total_count, counts)
+
+ @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
def _prob(self, counts):
return math_ops.exp(self._log_prob(counts))
def _mean(self):
- return self.n * (self.alpha / self.alpha_sum[..., None])
+ return self.total_count * (self.concentration /
+ self.total_concentration[..., None])
@distribution_util.AppendDocstring(
"""The covariance for each batch member is defined as the following:
- ```
+ ```none
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
(n + alpha_0) / (1 + alpha_0)
```
- where `alpha_0 = sum_j alpha_j`.
+ where `concentration = alpha` and
+ `total_concentration = alpha_0 = sum_j alpha_j`.
The covariance between elements in a batch is defined as:
- ```
+ ```none
Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
(n + alpha_0) / (1 + alpha_0)
```
@@ -280,40 +291,55 @@ class DirichletMultinomial(distribution.Distribution):
def _variance(self):
scale = self._variance_scale_term()
x = scale * self._mean()
- return x * (self.n * scale - x)
+ return x * (self.total_count * scale - x)
def _variance_scale_term(self):
"""Helper to `_covariance` and `_variance` which computes a shared scale."""
# We must take care to expand back the last dim whenever we use the
- # alpha_sum.
- c0 = self.alpha_sum[..., None]
- return math_ops.sqrt((1. + c0 / self.n) / (1. + c0))
+ # total_concentration.
+ c0 = self.total_concentration[..., None]
+ return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))
- def _assert_valid_counts(self, counts):
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of the concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ check_ops.assert_rank_at_least(
+ concentration, 1,
+ message="Concentration parameter must have >=1 dimensions."),
+ check_ops.assert_less(
+ 1, array_ops.shape(concentration)[-1],
+ message="Concentration parameter must have event_size >= 2."),
+ ], concentration)
+
+ def _maybe_assert_valid_total_count(self, total_count, validate_args):
+ if not validate_args:
+ return total_count
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_non_negative(
+ total_count,
+ message="total_count must be non-negative."),
+ distribution_util.assert_integer_form(
+ total_count,
+ message="total_count cannot contain fractional values."),
+ ], total_count)
+
+ def _maybe_assert_valid_sample(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
- counts = ops.convert_to_tensor(counts, name="counts")
if not self.validate_args:
return counts
- candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
return control_flow_ops.with_dependencies([
- check_ops.assert_non_negative(counts),
+ check_ops.assert_non_negative(
+ counts,
+ message="counts must be non-negative."),
check_ops.assert_equal(
- self._n, candidate_n,
- message="counts do not sum to n"),
- distribution_util.assert_integer_form(counts)], counts)
-
- def _assert_valid_alpha(self, alpha, validate_args):
- alpha = ops.convert_to_tensor(alpha, name="alpha")
- if not validate_args:
- return alpha
- return control_flow_ops.with_dependencies(
- [check_ops.assert_rank_at_least(alpha, 1),
- check_ops.assert_positive(alpha)], alpha)
-
- def _assert_valid_n(self, n, validate_args):
- n = ops.convert_to_tensor(n, name="n")
- if not validate_args:
- return n
- return control_flow_ops.with_dependencies(
- [check_ops.assert_non_negative(n),
- distribution_util.assert_integer_form(n)], n)
+ self.total_count, math_ops.reduce_sum(counts, -1),
+ message="counts last-dimension must sum to `self.total_count`"),
+ distribution_util.assert_integer_form(
+ counts,
+ message="counts cannot contain fractional components."),
+ ], counts)
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
index 79301436c0..23cc63df59 100644
--- a/tensorflow/contrib/distributions/python/ops/exponential.py
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -29,36 +29,64 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+__all__ = [
+ "Exponential",
+ "ExponentialWithSoftplusRate",
+]
+
+
class Exponential(gamma.Gamma):
- """The Exponential distribution with rate parameter lam.
+ """Exponential distribution.
+
+ The Exponential distribution is parameterized by an event `rate` parameter.
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; lambda, x > 0) = exp(-lambda x) / Z
+ Z = 1 / lambda
+ ```
+
+ where `rate = lambda` and `Z` is the normalizing constant.
+
+ The Exponential distribution is a special case of the Gamma distribution,
+ i.e.,
+
+ ```python
+ Exponential(rate) = Gamma(concentration=1., rate)
+ ```
- The PDF of this distribution is:
+ The Exponential distribution uses a `rate` parameter, or "inverse scale",
+ which can be intuited as,
- ```prob(x) = (lam * e^(-lam * x)), x > 0```
+ ```none
+ X ~ Exponential(rate=1)
+ Y = X / rate
+ ```
- Note that the Exponential distribution is a special case of the Gamma
- distribution, with Exponential(lam) = Gamma(1, lam).
"""
def __init__(self,
- lam,
+ rate,
validate_args=False,
allow_nan_stats=True,
name="Exponential"):
- """Construct Exponential distribution with parameter `lam`.
+ """Construct Exponential distribution with parameter `rate`.
Args:
- lam: Floating point tensor, the rate of the distribution(s).
- `lam` must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `lam > 0`, and that `x > 0` in the methods `prob(x)` and `log_prob(x)`.
- If `validate_args` is `False` and the inputs are invalid, correct
- behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prepend to all ops created by this distribution.
+ rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
+ positive values.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
@@ -66,30 +94,30 @@ class Exponential(gamma.Gamma):
# true in the parent class "Gamma." Therefore, passing
# allow_nan_stats=True
# through to the parent class results in unnecessary asserts.
- with ops.name_scope(name, values=[lam]) as ns:
- self._lam = ops.convert_to_tensor(lam, name="lam")
+ with ops.name_scope(name, values=[rate]) as ns:
+ self._rate = ops.convert_to_tensor(rate, name="rate")
super(Exponential, self).__init__(
- alpha=array_ops.ones((), dtype=self._lam.dtype),
- beta=self._lam,
+ concentration=array_ops.ones((), dtype=self._rate.dtype),
+ rate=self._rate,
allow_nan_stats=allow_nan_stats,
validate_args=validate_args,
name=ns)
- # While the Gamma distribution is not re-parameterizable, the
- # exponential distribution is.
+ # While the Gamma distribution is not reparameterizable, the exponential
+ # distribution is.
self._reparameterization_type = True
self._parameters = parameters
- self._graph_parents += [self._lam]
+ self._graph_parents += [self._rate]
@staticmethod
def _param_shapes(sample_shape):
- return {"lam": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
+ return {"rate": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
@property
- def lam(self):
- return self._lam
+ def rate(self):
+ return self._rate
def _sample_n(self, n, seed=None):
- shape = array_ops.concat(([n], array_ops.shape(self._lam)), 0)
+ shape = array_ops.concat(([n], array_ops.shape(self._rate)), 0)
# Sample uniformly-at-random from the open-interval (0, 1).
sampled = random_ops.random_uniform(
shape,
@@ -98,22 +126,22 @@ class Exponential(gamma.Gamma):
maxval=array_ops.ones((), dtype=self.dtype),
seed=seed,
dtype=self.dtype)
- return -math_ops.log(sampled) / self._lam
+ return -math_ops.log(sampled) / self._rate
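`_sample_n` uses inverse-transform sampling: if `U ~ Uniform(0, 1)` then `-log(U) / rate ~ Exponential(rate)`. A minimal NumPy sketch (illustrative only; the helper name is hypothetical):

```python
import numpy as np

def sample_exponential(rate, n, seed=None):
    # U ~ Uniform(0, 1)  =>  -log(U) / rate ~ Exponential(rate).
    rng = np.random.default_rng(seed)
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=n)
    return -np.log(u) / rate

print(sample_exponential(rate=2.0, n=100000, seed=0).mean())  # ~ 1 / rate = 0.5
```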
-class ExponentialWithSoftplusLam(Exponential):
- """Exponential with softplus transform on `lam`."""
+class ExponentialWithSoftplusRate(Exponential):
+ """Exponential with softplus transform on `rate`."""
def __init__(self,
- lam,
+ rate,
validate_args=False,
allow_nan_stats=True,
- name="ExponentialWithSoftplusLam"):
+ name="ExponentialWithSoftplusRate"):
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[lam]) as ns:
- super(ExponentialWithSoftplusLam, self).__init__(
- lam=nn.softplus(lam, name="softplus_lam"),
+ with ops.name_scope(name, values=[rate]) as ns:
+ super(ExponentialWithSoftplusRate, self).__init__(
+ rate=nn.softplus(rate, name="softplus_rate"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
index d7fc3a86f6..a8e1173d98 100644
--- a/tensorflow/contrib/distributions/python/ops/gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -36,107 +36,143 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+__all__ = [
+ "Gamma",
+ "GammaWithSoftplusConcentrationRate",
+]
+
+
class Gamma(distribution.Distribution):
- """The `Gamma` distribution with parameter alpha and beta.
+ """Gamma distribution.
+
+ The Gamma distribution is defined over positive real numbers using
+ parameters `concentration` (aka "alpha") and `rate` (aka "beta").
+
+ #### Mathematical Details
+
+ The probability density function (pdf) is,
- The parameters are the shape and inverse scale parameters alpha, beta.
+ ```none
+ pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
+ Z = Gamma(alpha) beta**(-alpha)
+ ```
+
+ where:
+
+ * `concentration = alpha`, `alpha > 0`,
+ * `rate = beta`, `beta > 0`,
+ * `Z` is the normalizing constant, and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
- The PDF of this distribution is:
+ The cumulative density function (cdf) is,
+
+ ```none
+ cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
+ ```
- ```pdf(x) = (beta^alpha)(x^(alpha-1))e^(-x*beta)/Gamma(alpha), x > 0```
+ where `GammaInc` is the [lower incomplete Gamma function](
+ https://en.wikipedia.org/wiki/Incomplete_gamma_function).
- and the CDF of this distribution is:
+ The parameters can be intuited via their relationship to mean and stddev,
- ```cdf(x) = GammaInc(alpha, beta * x) / Gamma(alpha), x > 0```
+ ```none
+ concentration = alpha = (mean / stddev)**2
+ rate = beta = mean / stddev**2 = concentration / mean
+ ```
- where GammaInc is the incomplete lower Gamma function.
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
- WARNING: This distribution may draw 0-valued samples for small alpha values.
- See the note on `tf.random_gamma`.
+ WARNING: This distribution may draw 0-valued samples for small `concentration`
+ values. See note in `tf.random_gamma` docstring.
- Examples:
+ #### Examples
```python
- dist = Gamma(alpha=3.0, beta=2.0)
- dist2 = Gamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
+ dist = Gamma(concentration=3.0, rate=2.0)
+ dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
"""
def __init__(self,
- alpha,
- beta,
+ concentration,
+ rate,
validate_args=False,
allow_nan_stats=True,
name="Gamma"):
- """Construct Gamma distributions with parameters `alpha` and `beta`.
+ """Construct Gamma with `concentration` and `rate` parameters.
- The parameters `alpha` and `beta` must be shaped in a way that supports
- broadcasting (e.g. `alpha + beta` is a valid operation).
+ The parameters `concentration` and `rate` must be shaped in a way that
+ supports broadcasting (e.g. `concentration + rate` is a valid operation).
Args:
- alpha: Floating point tensor, the shape params of the
- distribution(s).
- alpha must contain only positive values.
- beta: Floating point tensor, the inverse scale params of the
- distribution(s).
- beta must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `a > 0`, `b > 0`, and that `x > 0` in the methods `prob(x)` and
- `log_prob(x)`. If `validate_args` is `False` and the inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prepend to all ops created by this distribution.
+ concentration: Floating point tensor, the concentration params of the
+ distribution(s). Must contain only positive values.
+ rate: Floating point tensor, the inverse scale params of the
+ distribution(s). Must contain only positive values.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
- TypeError: if `alpha` and `beta` are different dtypes.
+ TypeError: if `concentration` and `rate` are different dtypes.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[alpha, beta]) as ns:
+ with ops.name_scope(name, values=[concentration, rate]) as ns:
with ops.control_dependencies([
- check_ops.assert_positive(alpha),
- check_ops.assert_positive(beta),
+ check_ops.assert_positive(concentration),
+ check_ops.assert_positive(rate),
] if validate_args else []):
- self._alpha = array_ops.identity(alpha, name="alpha")
- self._beta = array_ops.identity(beta, name="beta")
- contrib_tensor_util.assert_same_float_dtype((self._alpha, self._beta))
+ self._concentration = array_ops.identity(
+ concentration, name="concentration")
+ self._rate = array_ops.identity(rate, name="rate")
+ contrib_tensor_util.assert_same_float_dtype(
+ [self._concentration, self._rate])
super(Gamma, self).__init__(
- dtype=self._alpha.dtype,
+ dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
is_continuous=True,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
parameters=parameters,
- graph_parents=[self._alpha, self._beta],
+ graph_parents=[self._concentration,
+ self._rate],
name=ns)
@staticmethod
def _param_shapes(sample_shape):
return dict(
- zip(("alpha", "beta"), ([ops.convert_to_tensor(
+ zip(("concentration", "rate"), ([ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 2)))
@property
- def alpha(self):
- """Shape parameter."""
- return self._alpha
+ def concentration(self):
+ """Concentration parameter."""
+ return self._concentration
@property
- def beta(self):
- """Inverse scale parameter."""
- return self._beta
+ def rate(self):
+ """Rate parameter."""
+ return self._rate
def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.alpha), array_ops.shape(self.beta))
+ array_ops.shape(self.concentration),
+ array_ops.shape(self.rate))
def _batch_shape(self):
return array_ops.broadcast_static_shape(
- self.alpha.get_shape(), self.beta.get_shape())
+ self.concentration.get_shape(),
+ self.rate.get_shape())
def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)
@@ -144,99 +180,101 @@ class Gamma(distribution.Distribution):
def _event_shape(self):
return tensor_shape.scalar()
+ @distribution_util.AppendDocstring(
+ """Note: See `tf.random_gamma` docstring for sampling details and
+ caveats.""")
def _sample_n(self, n, seed=None):
- """See the documentation for tf.random_gamma for more details."""
- return random_ops.random_gamma([n],
- self.alpha,
- beta=self.beta,
- dtype=self.dtype,
- seed=seed)
+ return random_ops.random_gamma(
+ shape=[n],
+ alpha=self.concentration,
+ beta=self.rate,
+ dtype=self.dtype,
+ seed=seed)
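The `tf.random_gamma` caveat referenced above can be observed directly; a small sketch (the tiny `concentration` value is an arbitrary illustration):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Very small concentration values can underflow to exact zeros in the
# underlying tf.random_gamma draw, which makes downstream log-probabilities
# non-finite.
dist = ds.Gamma(concentration=0.01, rate=1.0)
samples = dist.sample(10000, seed=42)

with tf.Session() as sess:
  s = sess.run(samples)
  print("zero-valued draws:", (s == 0.).sum())
```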
def _log_prob(self, x):
- x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if
- self.validate_args else [], x)
- contrib_tensor_util.assert_same_float_dtype(tensors=[x],
- dtype=self.dtype)
- return (self.alpha * math_ops.log(self.beta) +
- (self.alpha - 1.) * math_ops.log(x) -
- self.beta * x -
- math_ops.lgamma(self.alpha))
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
def _log_cdf(self, x):
- x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if
- self.validate_args else [], x)
- contrib_tensor_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
- # Note that igamma returns the regularized incomplete gamma function,
- # which is what we want for the CDF.
- return math_ops.log(math_ops.igamma(self.alpha, self.beta * x))
+ return math_ops.log(self._cdf(x))
def _cdf(self, x):
- return math_ops.igamma(self.alpha, self.beta * x)
+ x = self._maybe_assert_valid_sample(x)
+ # Note that igamma returns the regularized incomplete gamma function,
+ # which is what we want for the CDF.
+ return math_ops.igamma(self.concentration, self.rate * x)
- @distribution_util.AppendDocstring(
- """This is defined to be
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ return (self.concentration - 1.) * math_ops.log(x) - self.rate * x
- ```
- entropy = alpha - log(beta) + log(Gamma(alpha))
- + (1-alpha)digamma(alpha)
- ```
+ def _log_normalization(self):
+ return (math_ops.lgamma(self.concentration)
+ - self.concentration * math_ops.log(self.rate))
- where digamma(alpha) is the digamma function.
- """)
def _entropy(self):
- return (self.alpha -
- math_ops.log(self.beta) +
- math_ops.lgamma(self.alpha) +
- (1. - self.alpha) * math_ops.digamma(self.alpha))
+ return (self.concentration
+ - math_ops.log(self.rate)
+ + math_ops.lgamma(self.concentration)
+ + ((1. - self.concentration) *
+ math_ops.digamma(self.concentration)))
def _mean(self):
- return self.alpha / self.beta
+ return self.concentration / self.rate
def _variance(self):
- return self.alpha / math_ops.square(self.beta)
+ return self.concentration / math_ops.square(self.rate)
def _stddev(self):
- return math_ops.sqrt(self.alpha) / self.beta
+ return math_ops.sqrt(self.concentration) / self.rate
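A NumPy/SciPy cross-check of the moment formulas implemented above (SciPy is used only as an independent reference; its `gamma` takes `scale = 1 / rate`):

```python
import numpy as np
from scipy import stats

concentration, rate = 3.0, 2.0
ref = stats.gamma(a=concentration, scale=1. / rate)

# mean = concentration / rate, variance = concentration / rate**2,
# stddev = sqrt(concentration) / rate.
assert np.isclose(ref.mean(), concentration / rate)
assert np.isclose(ref.var(), concentration / rate**2)
assert np.isclose(ref.std(), np.sqrt(concentration) / rate)
```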
@distribution_util.AppendDocstring(
- """The mode of a gamma distribution is `(alpha - 1) / beta` when
- `alpha > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is `False`,
+ """The mode of a gamma distribution is `(shape - 1) / rate` when
+ `shape > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is `False`,
an exception will be raised rather than returning `NaN`.""")
def _mode(self):
- mode = (self.alpha - 1.) / self.beta
+ mode = (self.concentration - 1.) / self.rate
if self.allow_nan_stats:
- nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return array_ops.where(
- self.alpha >= 1.,
- mode,
- array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
+ nan = array_ops.fill(
+ self.batch_shape_tensor(),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
+ return array_ops.where(self.concentration > 1., mode, nan)
else:
return control_flow_ops.with_dependencies([
check_ops.assert_less(
- array_ops.ones((), self.dtype),
- self.alpha,
- message="mode not defined for components of alpha <= 1"),
+ array_ops.ones([], self.dtype),
+ self.concentration,
+ message="mode not defined when any concentration <= 1"),
], mode)
+ def _maybe_assert_valid_sample(self, x):
+ contrib_tensor_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
+ if not self.validate_args:
+ return x
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(x),
+ ], x)
+
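How the `validate_args` path above behaves at runtime, sketched assuming TensorFlow 1.x graph execution:

```python
import tensorflow as tf

ds = tf.contrib.distributions

strict = ds.Gamma(concentration=2.0, rate=1.0, validate_args=True)
lenient = ds.Gamma(concentration=2.0, rate=1.0)

with tf.Session() as sess:
  # With validate_args=True the assert_positive dependency rejects x <= 0.
  try:
    sess.run(strict.log_prob(-1.0))
  except tf.errors.InvalidArgumentError:
    print("non-positive sample rejected")
  # Without validation the value flows through and yields nan.
  print(sess.run(lenient.log_prob(-1.0)))
```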
-class GammaWithSoftplusAlphaBeta(Gamma):
- """Gamma with softplus transform on `alpha` and `beta`."""
+class GammaWithSoftplusConcentrationRate(Gamma):
+ """`Gamma` with softplus of `concentration` and `rate`."""
def __init__(self,
- alpha,
- beta,
+ concentration,
+ rate,
validate_args=False,
allow_nan_stats=True,
- name="GammaWithSoftplusAlphaBeta"):
+ name="GammaWithSoftplusConcentrationRate"):
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[alpha, beta]) as ns:
- super(GammaWithSoftplusAlphaBeta, self).__init__(
- alpha=nn.softplus(alpha, name="softplus_alpha"),
- beta=nn.softplus(beta, name="softplus_beta"),
+ with ops.name_scope(name, values=[concentration, rate]) as ns:
+ super(GammaWithSoftplusConcentrationRate, self).__init__(
+ concentration=nn.softplus(concentration,
+ name="softplus_concentration"),
+ rate=nn.softplus(rate, name="softplus_rate"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
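The softplus wrapper is meant for unconstrained parameterizations, e.g. when the parameters are free `tf.Variable`s; a sketch (initial values are arbitrary):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Unconstrained variables; softplus maps them into (0, inf) inside the class.
raw_concentration = tf.Variable(0.5)
raw_rate = tf.Variable(-1.0)

dist = ds.GammaWithSoftplusConcentrationRate(
    concentration=raw_concentration, rate=raw_rate)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([dist.concentration, dist.rate]))  # both strictly positive
```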
@@ -256,15 +294,16 @@ def _kl_gamma_gamma(g0, g1, name=None):
Returns:
kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
"""
- with ops.name_scope(name, "kl_gamma_gamma",
- values=[g0.alpha, g0.beta, g1.alpha, g1.beta]):
+ with ops.name_scope(name, "kl_gamma_gamma", values=[
+ g0.concentration, g0.rate, g1.concentration, g1.rate]):
# Result from:
# http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
# For derivation see:
# http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long
- return ((g0.alpha - g1.alpha) * math_ops.digamma(g0.alpha)
- + math_ops.lgamma(g1.alpha)
- - math_ops.lgamma(g0.alpha)
- + g1.alpha * math_ops.log(g0.beta)
- - g1.alpha * math_ops.log(g1.beta)
- + g0.alpha * (g1.beta / g0.beta - 1.))
+ return (((g0.concentration - g1.concentration)
+ * math_ops.digamma(g0.concentration))
+ + math_ops.lgamma(g1.concentration)
+ - math_ops.lgamma(g0.concentration)
+ + g1.concentration * math_ops.log(g0.rate)
+ - g1.concentration * math_ops.log(g1.rate)
+ + g0.concentration * (g1.rate / g0.rate - 1.))
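The closed form above can be sanity-checked against a Monte Carlo estimate; a NumPy/SciPy sketch (sample size and parameter values are arbitrary):

```python
import numpy as np
from scipy import stats
from scipy.special import digamma, gammaln

def kl_gamma_gamma(c0, r0, c1, r1):
  # Mirrors the closed form used in _kl_gamma_gamma.
  return ((c0 - c1) * digamma(c0)
          + gammaln(c1) - gammaln(c0)
          + c1 * np.log(r0) - c1 * np.log(r1)
          + c0 * (r1 / r0 - 1.))

c0, r0, c1, r1 = 3.0, 2.0, 4.5, 1.5
g0 = stats.gamma(a=c0, scale=1. / r0)
g1 = stats.gamma(a=c1, scale=1. / r1)
x = g0.rvs(size=200000, random_state=0)
mc_estimate = np.mean(g0.logpdf(x) - g1.logpdf(x))
print(kl_gamma_gamma(c0, r0, c1, r1), mc_estimate)  # agree to ~2 decimals
```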
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 9fa6e55fa8..3bfc169c6b 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -34,102 +35,144 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+__all__ = [
+ "InverseGamma",
+ "InverseGammaWithSoftplusConcentrationRate",
+]
+
+
class InverseGamma(distribution.Distribution):
- """The `InverseGamma` distribution with parameter alpha and beta.
+ """InverseGamma distribution.
- The parameters are the shape and inverse scale parameters alpha, beta.
+ The `InverseGamma` distribution is defined over positive real numbers using
+ parameters `concentration` (aka "alpha") and `rate` (aka "beta").
- The PDF of this distribution is:
+ #### Mathematical Details
- ```pdf(x) = (beta^alpha)/Gamma(alpha)(x^(-alpha-1))e^(-beta/x), x > 0```
+ The probability density function (pdf) is,
- and the CDF of this distribution is:
+ ```none
+ pdf(x; alpha, beta, x > 0) = x**(-alpha - 1) exp(-beta / x) / Z
+ Z = Gamma(alpha) beta**-alpha
+ ```
- ```cdf(x) = GammaInc(alpha, beta / x) / Gamma(alpha), x > 0```
+ where:
- where GammaInc is the upper incomplete Gamma function.
+ * `concentration = alpha`,
+ * `rate = beta`,
+ * `Z` is the normalizing constant, and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
- Examples:
+  The cumulative distribution function (cdf) is,
+
+ ```none
+ cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta / x) / Gamma(alpha)
+ ```
+
+ where `GammaInc` is the [upper incomplete Gamma function](
+ https://en.wikipedia.org/wiki/Incomplete_gamma_function).
+
+  The parameters can be intuited via their relationship to mean and stddev
+  (both finite only when `concentration > 2`),
+
+  ```none
+  concentration = alpha = (mean / stddev)**2 + 2
+  rate = beta = mean * ((mean / stddev)**2 + 1)
+ ```
+
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
+
+  WARNING: For small `concentration` values, the underlying `tf.random_gamma`
+  draw may underflow to 0, yielding non-finite (`inf`) samples from this
+  distribution. See note in `tf.random_gamma` docstring.
+
+ #### Examples
```python
- dist = InverseGamma(alpha=3.0, beta=2.0)
- dist2 = InverseGamma(alpha=[3.0, 4.0], beta=[2.0, 3.0])
+ dist = InverseGamma(concentration=3.0, rate=2.0)
+ dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
"""
def __init__(self,
- alpha,
- beta,
+ concentration,
+ rate,
validate_args=False,
allow_nan_stats=True,
name="InverseGamma"):
- """Construct InverseGamma distributions with parameters `alpha` and `beta`.
+ """Construct InverseGamma with `concentration` and `rate` parameters.
- The parameters `alpha` and `beta` must be shaped in a way that supports
- broadcasting (e.g. `alpha + beta` is a valid operation).
+ The parameters `concentration` and `rate` must be shaped in a way that
+ supports broadcasting (e.g. `concentration + rate` is a valid operation).
Args:
- alpha: Floating point tensor, the shape params of the
- distribution(s).
- alpha must contain only positive values.
- beta: Floating point tensor, the scale params of the distribution(s).
- beta must contain only positive values.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `a > 0`, `b > 0`, and that `x > 0` in the methods `prob(x)` and
- `log_prob(x)`. If `validate_args` is `False` and the inputs are
- invalid, correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: The name to prepend to all ops created by this distribution.
+ concentration: Floating point tensor, the concentration params of the
+ distribution(s). Must contain only positive values.
+      rate: Floating point tensor, the scale params of the
+ distribution(s). Must contain only positive values.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
+
Raises:
- TypeError: if `alpha` and `beta` are different dtypes.
+ TypeError: if `concentration` and `rate` are different dtypes.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[alpha, beta]) as ns:
+ with ops.name_scope(name, values=[concentration, rate]) as ns:
with ops.control_dependencies([
- check_ops.assert_positive(alpha),
- check_ops.assert_positive(beta),
+ check_ops.assert_positive(concentration),
+ check_ops.assert_positive(rate),
] if validate_args else []):
- self._alpha = array_ops.identity(alpha, name="alpha")
- self._beta = array_ops.identity(beta, name="beta")
+ self._concentration = array_ops.identity(
+ concentration, name="concentration")
+ self._rate = array_ops.identity(rate, name="rate")
+ contrib_tensor_util.assert_same_float_dtype(
+ [self._concentration, self._rate])
super(InverseGamma, self).__init__(
- dtype=self._alpha.dtype,
+ dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
is_continuous=True,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
parameters=parameters,
- graph_parents=[self._alpha, self._beta],
+ graph_parents=[self._concentration,
+ self._rate],
name=ns)
@staticmethod
def _param_shapes(sample_shape):
return dict(
- zip(("alpha", "beta"), ([ops.convert_to_tensor(
+ zip(("concentration", "rate"), ([ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32)] * 2)))
@property
- def alpha(self):
- """Shape parameter."""
- return self._alpha
+ def concentration(self):
+ """Concentration parameter."""
+ return self._concentration
@property
- def beta(self):
- """Scale parameter."""
- return self._beta
+ def rate(self):
+ """Rate parameter."""
+ return self._rate
def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.alpha), array_ops.shape(self.beta))
+ array_ops.shape(self.concentration),
+ array_ops.shape(self.rate))
def _batch_shape(self):
return array_ops.broadcast_static_shape(
- self.alpha.get_shape(), self.beta.get_shape())
+ self.concentration.get_shape(),
+ self.rate.get_shape())
def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)
@@ -137,17 +180,19 @@ class InverseGamma(distribution.Distribution):
def _event_shape(self):
return tensor_shape.scalar()
+ @distribution_util.AppendDocstring(
+ """Note: See `tf.random_gamma` docstring for sampling details and
+ caveats.""")
def _sample_n(self, n, seed=None):
- """See the documentation for tf.random_gamma for more details."""
- return 1. / random_ops.random_gamma([n], self.alpha, beta=self.beta,
- dtype=self.dtype, seed=seed)
+ return 1. / random_ops.random_gamma(
+ shape=[n],
+ alpha=self.concentration,
+ beta=self.rate,
+ dtype=self.dtype,
+ seed=seed)
def _log_prob(self, x):
- x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if
- self.validate_args else [], x)
- return (self.alpha * math_ops.log(self.beta) -
- math_ops.lgamma(self.alpha) -
- (self.alpha + 1.) * math_ops.log(x) - self.beta / x)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
@@ -156,84 +201,100 @@ class InverseGamma(distribution.Distribution):
return math_ops.log(self._cdf(x))
def _cdf(self, x):
- x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if
- self.validate_args else [], x)
+ x = self._maybe_assert_valid_sample(x)
# Note that igammac returns the upper regularized incomplete gamma
# function Q(a, x), which is what we want for the CDF.
- return math_ops.igammac(self.alpha, self.beta / x)
+ return math_ops.igammac(self.concentration, self.rate / x)
- @distribution_util.AppendDocstring(
- """This is defined to be
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ return -(1. + self.concentration) * math_ops.log(x) - self.rate / x
- ```
- entropy = alpha - log(beta) + log(Gamma(alpha))
- + (1-alpha)digamma(alpha)
- ```
+ def _log_normalization(self):
+ return (math_ops.lgamma(self.concentration)
+ - self.concentration * math_ops.log(self.rate))
- where digamma(alpha) is the digamma function.""")
def _entropy(self):
- return (self.alpha +
- math_ops.log(self.beta) +
- math_ops.lgamma(self.alpha) -
- (1. + self.alpha) * math_ops.digamma(self.alpha))
+ return (self.concentration
+ + math_ops.log(self.rate)
+ + math_ops.lgamma(self.concentration)
+ - ((1. + self.concentration) *
+ math_ops.digamma(self.concentration)))
@distribution_util.AppendDocstring(
- """The mean of an inverse gamma distribution is `beta / (alpha - 1)`,
- when `alpha > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is
- `False`, an exception will be raised rather than returning `NaN`""")
+ """The mean of an inverse gamma distribution is
+ `rate / (concentration - 1)`, when `concentration > 1`, and `NaN`
+ otherwise. If `self.allow_nan_stats` is `False`, an exception will be
+ raised rather than returning `NaN`""")
def _mean(self):
- mean = self.beta / (self.alpha - 1.)
+ mean = self.rate / (self.concentration - 1.)
if self.allow_nan_stats:
- nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return array_ops.where(
- self.alpha > 1., mean,
- array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
+ nan = array_ops.fill(
+ self.batch_shape_tensor(),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
+ return array_ops.where(self.concentration > 1., mean, nan)
else:
return control_flow_ops.with_dependencies([
check_ops.assert_less(
- array_ops.ones((), self.dtype), self.alpha,
- message="mean not defined for components of self.alpha <= 1"),
+ array_ops.ones([], self.dtype), self.concentration,
+ message="mean undefined when any concentration <= 1"),
], mean)
@distribution_util.AppendDocstring(
- """Variance for inverse gamma is defined only for `alpha > 2`. If
+ """Variance for inverse gamma is defined only for `concentration > 2`. If
`self.allow_nan_stats` is `False`, an exception will be raised rather
than returning `NaN`.""")
def _variance(self):
- var = (math_ops.square(self.beta) /
- (math_ops.square(self.alpha - 1.) * (self.alpha - 2.)))
+ var = (math_ops.square(self.rate)
+ / math_ops.square(self.concentration - 1.)
+ / (self.concentration - 2.))
if self.allow_nan_stats:
- nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return array_ops.where(
- self.alpha > 2., var,
- array_ops.fill(self.batch_shape_tensor(), nan, name="nan"))
+ nan = array_ops.fill(
+ self.batch_shape_tensor(),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
+ return array_ops.where(self.concentration > 2., var, nan)
else:
return control_flow_ops.with_dependencies([
check_ops.assert_less(
- constant_op.constant(2., dtype=self.dtype), self.alpha,
- message="variance not defined for components of alpha <= 2"),
+ constant_op.constant(2., dtype=self.dtype),
+ self.concentration,
+ message="variance undefined when any concentration <= 2"),
], var)
+ @distribution_util.AppendDocstring(
+ """The mode of an inverse gamma distribution is `rate / (concentration +
+ 1)`.""")
def _mode(self):
- """The mode of an inverse gamma distribution is `beta / (alpha + 1)`."""
- return self.beta / (self.alpha + 1.)
+ return self.rate / (1. + self.concentration)
+
+ def _maybe_assert_valid_sample(self, x):
+ contrib_tensor_util.assert_same_float_dtype(
+ tensors=[x], dtype=self.dtype)
+ if not self.validate_args:
+ return x
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(x),
+ ], x)
-class InverseGammaWithSoftplusAlphaBeta(InverseGamma):
- """Inverse Gamma with softplus applied to `alpha` and `beta`."""
+class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
+ """`InverseGamma` with softplus of `concentration` and `rate`."""
def __init__(self,
- alpha,
- beta,
+ concentration,
+ rate,
validate_args=False,
allow_nan_stats=True,
- name="InverseGammaWithSoftplusAlphaBeta"):
+ name="InverseGammaWithSoftplusConcentrationRate"):
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[alpha, beta]) as ns:
- super(InverseGammaWithSoftplusAlphaBeta, self).__init__(
- alpha=nn.softplus(alpha, name="softplus_alpha"),
- beta=nn.softplus(beta, name="softplus_gamma"),
+ with ops.name_scope(name, values=[concentration, rate]) as ns:
+ super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
+ concentration=nn.softplus(concentration,
+ name="softplus_concentration"),
+ rate=nn.softplus(rate, name="softplus_rate"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=ns)
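A NumPy/SciPy cross-check of the `InverseGamma` density and mean against the docstring formulas (SciPy's `invgamma(a, scale)` matches this parameterization with `scale = rate`):

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

concentration, rate = 3.0, 2.0
x = 0.7

# log pdf per the docstring: x**(-alpha - 1) exp(-beta / x) / Z,
# with Z = Gamma(alpha) * beta**(-alpha).
log_pdf = (-(concentration + 1.) * np.log(x) - rate / x
           - (gammaln(concentration) - concentration * np.log(rate)))

ref = stats.invgamma(a=concentration, scale=rate)
assert np.isclose(log_pdf, ref.logpdf(x))
assert np.isclose(rate / (concentration - 1.), ref.mean())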
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index cabb99d106..fa516141fb 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -30,9 +30,14 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-_poisson_prob_note = """
-Note thet the input value must be a non-negative floating point tensor with
-dtype `dtype` and whose shape can be broadcast with `self.lam`. `x` is only
+__all__ = [
+ "Poisson",
+]
+
+
+_poisson_sample_note = """
+Note that the input value must be a non-negative floating point tensor with
+dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only
legal if it is non-negative and its components are equal to integer values.
"""
@@ -40,63 +45,67 @@ legal if it is non-negative and its components are equal to integer values.
class Poisson(distribution.Distribution):
"""Poisson distribution.
- The Poisson distribution is parameterized by `lam`, the rate parameter.
+ The Poisson distribution is parameterized by an event `rate` parameter.
- The pmf of this distribution is:
+ #### Mathematical Details
- ```
+ The probability mass function (pmf) is,
- pmf(k) = e^(-lam) * lam^k / k!, k >= 0
+ ```none
+ pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z
+ Z = exp(lambda).
```
+ where `rate = lambda` and `Z` is the normalizing constant.
+
"""
def __init__(self,
- lam,
+ rate,
validate_args=False,
allow_nan_stats=True,
name="Poisson"):
- """Construct Poisson distributions.
+ """Initialize a batch of Poisson distributions.
Args:
- lam: Floating point tensor, the rate parameter of the
- distribution(s). `lam` must be positive.
- validate_args: `Boolean`, default `False`. Whether to assert that
- `lam > 0` as well as inputs to `prob` computations are non-negative
- integers. If validate_args is `False`, then `prob` computations might
- return `NaN`, but can be evaluated at any real value.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g. mean/mode/etc...) is undefined for any
- batch member. If `True`, batch members with valid parameters leading to
- undefined statistics will return NaN for this statistic.
- name: A name for this distribution.
+ rate: Floating point tensor, the rate parameter of the
+ distribution(s). `rate` must be positive.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
- with ops.name_scope(name, values=[lam]) as ns:
- with ops.control_dependencies([check_ops.assert_positive(lam)] if
+ with ops.name_scope(name, values=[rate]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(rate)] if
validate_args else []):
- self._lam = array_ops.identity(lam, name="lam")
+ self._rate = array_ops.identity(rate, name="rate")
super(Poisson, self).__init__(
- dtype=self._lam.dtype,
+ dtype=self._rate.dtype,
is_continuous=False,
reparameterization_type=distribution.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
- graph_parents=[self._lam],
+ graph_parents=[self._rate],
name=ns)
@property
- def lam(self):
+ def rate(self):
"""Rate parameter."""
- return self._lam
+ return self._rate
def _batch_shape_tensor(self):
- return array_ops.shape(self.lam)
+ return array_ops.shape(self.rate)
def _batch_shape(self):
- return self.lam.get_shape()
+ return self.rate.get_shape()
def _event_shape_tensor(self):
return constant_op.constant([], dtype=dtypes.int32)
@@ -104,40 +113,47 @@ class Poisson(distribution.Distribution):
def _event_shape(self):
return tensor_shape.scalar()
- @distribution_util.AppendDocstring(_poisson_prob_note)
+ @distribution_util.AppendDocstring(_poisson_sample_note)
def _log_prob(self, x):
- x = self._assert_valid_sample(x, check_integer=True)
- return x * math_ops.log(self.lam) - self.lam - math_ops.lgamma(x + 1)
+ return self._log_unnormalized_prob(x) - self._log_normalization()
- @distribution_util.AppendDocstring(_poisson_prob_note)
+ @distribution_util.AppendDocstring(_poisson_sample_note)
def _prob(self, x):
return math_ops.exp(self._log_prob(x))
+ @distribution_util.AppendDocstring(_poisson_sample_note)
def _log_cdf(self, x):
return math_ops.log(self.cdf(x))
+ @distribution_util.AppendDocstring(_poisson_sample_note)
def _cdf(self, x):
x = self._assert_valid_sample(x, check_integer=False)
- return math_ops.igammac(math_ops.floor(x + 1), self.lam)
+ return math_ops.igammac(math_ops.floor(x + 1), self.rate)
+
+ def _log_normalization(self):
+ return self.rate
+
+ def _log_unnormalized_prob(self, x):
+ x = self._assert_valid_sample(x, check_integer=True)
+ return x * math_ops.log(self.rate) - math_ops.lgamma(x + 1)
def _mean(self):
- return array_ops.identity(self.lam)
+ return array_ops.identity(self.rate)
def _variance(self):
- return array_ops.identity(self.lam)
+ return array_ops.identity(self.rate)
@distribution_util.AppendDocstring(
- """Note that when `lam` is an integer, there are actually two modes.
- Namely, `lam` and `lam - 1` are both modes. Here we return
- only the larger of the two modes.""")
+ """Note: when `rate` is an integer, there are actually two modes: `rate`
+ and `rate - 1`. In this case we return the larger, i.e., `rate`.""")
def _mode(self):
- return math_ops.floor(self.lam)
+ return math_ops.floor(self.rate)
def _assert_valid_sample(self, x, check_integer=True):
- if not self.validate_args: return x
- with ops.name_scope("check_x", values=[x]):
- dependencies = [check_ops.assert_non_negative(x)]
- if check_integer:
- dependencies += [distribution_util.assert_integer_form(
- x, message="x has non-integer components.")]
- return control_flow_ops.with_dependencies(dependencies, x)
+ if not self.validate_args:
+ return x
+ dependencies = [check_ops.assert_non_negative(x)]
+ if check_integer:
+ dependencies += [distribution_util.assert_integer_form(
+ x, message="x has non-integer components.")]
+ return control_flow_ops.with_dependencies(dependencies, x)
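The `_cdf` implementation above relies on the identity `P(X <= k) = Q(k + 1, rate)`, with `Q` the regularized upper incomplete gamma function; a SciPy sketch of that identity:

```python
import numpy as np
from scipy import stats
from scipy.special import gammaincc

rate, k = 4.0, 2
# gammaincc is the regularized upper incomplete gamma function Q(a, x),
# the same quantity math_ops.igammac computes.
assert np.isclose(gammaincc(k + 1, rate), stats.poisson(mu=rate).cdf(k))
```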
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 24e48f0a8d..5f5c74184a 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -38,6 +38,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+__all__ = [
+ "WishartCholesky",
+ "WishartFull",
+]
+
+
class _WishartOperatorPD(distribution.Distribution):
"""The matrix Wishart distribution on positive definite matrices.
@@ -45,22 +51,22 @@ class _WishartOperatorPD(distribution.Distribution):
an instance of `OperatorPDBase`, which provides matrix-free access to a
symmetric positive definite operator, which defines the scale matrix.
- #### Mathematical details.
+ #### Mathematical Details
- The PDF of this distribution is,
+ The probability density function (pdf) is,
+ ```none
+ pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z
+ Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df)
```
- f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df)
- ```
-
- where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd,
- `k x k` matrix, and the normalizing constant `B(scale, df)` is given by:
- ```
- B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df)
- ```
+ where:
- where `Gamma_k` is the multivariate Gamma function.
+ * `df >= k` denotes the degrees of freedom,
+ * `scale` is a symmetric, positive definite, `k x k` matrix,
+ * `Z` is the normalizing constant, and,
+ * `Gamma_k` is the [multivariate Gamma function](
+ https://en.wikipedia.org/wiki/Multivariate_gamma_function).
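A NumPy/SciPy sketch of the normalizing constant `Z` above, assuming `scipy.special.multigammaln` for `Gamma_k` (parameter values are illustrative):

```python
import numpy as np
from scipy.special import multigammaln

def wishart_log_normalization(df, scale):
  # log Z = 0.5 df k log(2) + 0.5 df log|det(scale)| + log Gamma_k(0.5 df)
  k = scale.shape[0]
  return (0.5 * df * k * np.log(2.)
          + 0.5 * df * np.linalg.slogdet(scale)[1]
          + multigammaln(0.5 * df, k))

print(wishart_log_normalization(df=5., scale=np.eye(3)))
```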
#### Examples
@@ -86,19 +92,21 @@ class _WishartOperatorPD(distribution.Distribution):
Cholesky factored matrix. Example `log_prob` input takes a Cholesky and
`sample_n` returns a Cholesky when
`cholesky_input_output_matrices=True`.
- validate_args: `Boolean`, default `False`. Whether to validate input with
- asserts. If `validate_args` is `False`, and the inputs are invalid,
- correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g., mean, mode) is undefined for any batch
- member. If True, batch members with valid parameters leading to
- undefined statistics will return `NaN` for this statistic.
- name: The name to give Ops created by the initializer.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
Raises:
TypeError: if scale is not floating-type
TypeError: if scale.dtype != df.dtype
- ValueError: if df < k, where scale operator event shape is `(k, k)`
+ ValueError: if df < k, where scale operator event shape is
+ `(k, k)`
"""
parameters = locals()
parameters.pop("self")
@@ -111,7 +119,9 @@ class _WishartOperatorPD(distribution.Distribution):
scale_operator_pd.dtype)
self._scale_operator_pd = scale_operator_pd
self._df = ops.convert_to_tensor(
- df, dtype=scale_operator_pd.dtype, name="df")
+ df,
+ dtype=scale_operator_pd.dtype,
+ name="df")
contrib_tensor_util.assert_same_float_dtype(
(self._df, self._scale_operator_pd))
if (self._scale_operator_pd.get_shape().ndims is None or
@@ -127,19 +137,22 @@ class _WishartOperatorPD(distribution.Distribution):
dim_val = tensor_util.constant_value(self._dimension)
if df_val is not None and dim_val is not None:
df_val = np.asarray(df_val)
- if not df_val.shape: df_val = (df_val,)
+ if not df_val.shape:
+ df_val = [df_val]
if any(df_val < dim_val):
raise ValueError(
- "Degrees of freedom (df = %s) cannot be less than dimension of "
- "scale matrix (scale.dimension = %s)"
+ "Degrees of freedom (df = %s) cannot be less than "
+ "dimension of scale matrix (scale.dimension = %s)"
% (df_val, dim_val))
elif validate_args:
assertions = check_ops.assert_less_equal(
self._dimension, self._df,
- message=("Degrees of freedom (df = %s) cannot be less than "
- "dimension of scale matrix (scale.dimension = %s)" %
+ message=("Degrees of freedom (df = %s) cannot be "
+ "less than dimension of scale matrix "
+ "(scale.dimension = %s)" %
(self._dimension, self._df)))
- self._df = control_flow_ops.with_dependencies([assertions], self._df)
+ self._df = control_flow_ops.with_dependencies(
+ [assertions], self._df)
super(_WishartOperatorPD, self).__init__(
dtype=self._scale_operator_pd.dtype,
validate_args=validate_args,
@@ -321,7 +334,7 @@ class _WishartOperatorPD(distribution.Distribution):
# Complexity: O(nbk^2)
log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
0.5 * trace_scale_inv_x -
- self.log_normalizing_constant())
+ self.log_normalization())
# Set shape hints.
# Try to merge what we know from the input then what we know from the
@@ -349,7 +362,8 @@ class _WishartOperatorPD(distribution.Distribution):
def _mean(self):
if self.cholesky_input_output_matrices:
- return math_ops.sqrt(self.df) * self.scale_operator_pd.sqrt_to_dense()
+ return (math_ops.sqrt(self.df)
+ * self.scale_operator_pd.sqrt_to_dense())
return self.df * self.scale_operator_pd.to_dense()
def _variance(self):
@@ -384,7 +398,7 @@ class _WishartOperatorPD(distribution.Distribution):
self.dimension * math.log(2.) +
self.scale_operator_pd.log_det())
- def log_normalizing_constant(self, name="log_normalizing_constant"):
+ def log_normalization(self, name="log_normalization"):
"""Computes the log normalizing constant, log(Z)."""
with self._name_scope(name):
return (self.df * self.scale_operator_pd.sqrt_log_det() +
@@ -429,22 +443,21 @@ class WishartCholesky(_WishartOperatorPD):
another O(nbk^3) operation since most uses of Wishart will also use the
Cholesky factorization.
- #### Mathematical details.
-
- The PDF of this distribution is,
-
- ```
- f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df)
- ```
+ #### Mathematical Details
- where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd,
- `k x k` matrix, and the normalizing constant `B(scale, df)` is given by:
+ The probability density function (pdf) is,
- ```
- B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df)
+ ```none
+ pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z
+ Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df)
```
- where `Gamma_k` is the multivariate Gamma function.
+ where:
+ * `df >= k` denotes the degrees of freedom,
+ * `scale` is a symmetric, positive definite, `k x k` matrix,
+ * `Z` is the normalizing constant, and,
+ * `Gamma_k` is the [multivariate Gamma function](
+ https://en.wikipedia.org/wiki/Multivariate_gamma_function).
#### Examples
@@ -499,14 +512,15 @@ class WishartCholesky(_WishartOperatorPD):
Cholesky factored matrix. Example `log_prob` input takes a Cholesky and
`sample_n` returns a Cholesky when
`cholesky_input_output_matrices=True`.
- validate_args: `Boolean`, default `False`. Whether to validate input
- with asserts. If `validate_args` is `False`, and the inputs are invalid,
- correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g., mean, mode) is undefined for any batch
- member. If True, batch members with valid parameters leading to
- undefined statistics will return `NaN` for this statistic.
- name: The name scope to give class member ops.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")
@@ -531,22 +545,21 @@ class WishartFull(_WishartOperatorPD):
Evaluation of the pdf, determinant, and sampling are all `O(k^3)` operations
where `(k, k)` is the event space shape.
- #### Mathematical details.
-
- The PDF of this distribution is,
+ #### Mathematical Details
- ```
- f(X) = det(X)^(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / B(scale, df)
- ```
+ The probability density function (pdf) is,
- where `df >= k` denotes the degrees of freedom, `scale` is a symmetric, pd,
- `k x k` matrix, and the normalizing constant `B(scale, df)` is given by:
-
- ```
- B(scale, df) = 2^(0.5 df k) |det(scale)|^(0.5 df) Gamma_k(0.5 df)
+ ```none
+ pdf(X; df, scale) = det(X)**(0.5 (df-k-1)) exp(-0.5 tr[inv(scale) X]) / Z
+ Z = 2**(0.5 df k) |det(scale)|**(0.5 df) Gamma_k(0.5 df)
```
- where `Gamma_k` is the multivariate Gamma function.
+ where:
+ * `df >= k` denotes the degrees of freedom,
+ * `scale` is a symmetric, positive definite, `k x k` matrix,
+ * `Z` is the normalizing constant, and,
+ * `Gamma_k` is the [multivariate Gamma function](
+ https://en.wikipedia.org/wiki/Multivariate_gamma_function).
#### Examples
@@ -600,14 +613,15 @@ class WishartFull(_WishartOperatorPD):
Cholesky factored matrix. Example `log_prob` input takes a Cholesky and
`sample_n` returns a Cholesky when
`cholesky_input_output_matrices=True`.
- validate_args: `Boolean`, default `False`. Whether to validate input with
- asserts. If `validate_args` is `False`, and the inputs are invalid,
- correct behavior is not guaranteed.
- allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
- exception if a statistic (e.g., mean, mode) is undefined for any batch
- member. If True, batch members with valid parameters leading to
- undefined statistics will return `NaN` for this statistic.
- name: The name scope to give class member ops.
+ validate_args: Python `Boolean`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: `String` name prefixed to Ops created by this class.
"""
parameters = locals()
parameters.pop("self")