aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-22 10:40:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-22 11:47:39 -0700
commitf780673ff8e33b69a3ce0eab37bf8efd0ba26f18 (patch)
treea3a33b5089c47b381fd6e147c55253943f1c0fa3
parenta81c4f9cd01563e97fc6f179e4d70960fc9b02ae (diff)
Add DirichletMultinomial variance calculation.
Change: 128195102
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py80
-rw-r--r--tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py42
2 files changed, 122 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
index aec5b85699..866fb45524 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py
@@ -198,6 +198,86 @@ class DirichletMultinomialTest(tf.test.TestCase):
self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
self.assertTupleEqual((3,), mean1.shape)
+ def testVariance(self):
+ # Shape [2]
+ alpha = [1., 2]
+ ns = [2., 3., 4., 5.]
+ alpha_0 = np.sum(alpha)
+
+ # Diagonal entries are of the form:
+ # Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) *
+ # (alpha_sum + n) / (alpha_sum + 1)
+ variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
+ # Off diagonal entries are of the form:
+ # Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) *
+ # (alpha_sum + n) / (alpha_sum + 1)
+ covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2
+ # Shape [2, 2].
+ shared_matrix = np.array([
+ [variance_entry(alpha[0], alpha_0),
+ covariance_entry(alpha[0], alpha[1], alpha_0)],
+ [covariance_entry(alpha[1], alpha[0], alpha_0),
+ variance_entry(alpha[1], alpha_0)]])
+
+ with self.test_session():
+ for n in ns:
+ # n is shape [] and alpha is shape [2].
+ dist = tf.contrib.distributions.DirichletMultinomial(n, alpha)
+ variance = dist.variance()
+ expected_variance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix
+
+ self.assertEqual((2, 2), variance.get_shape())
+ self.assertAllClose(expected_variance, variance.eval())
+
+ def testVariance_n_alpha_broadcast(self):
+ alpha_v = [1., 2, 3]
+ alpha_0 = np.sum(alpha_v)
+
+ # Shape [4, 3]
+ alpha = np.array(4 * [alpha_v])
+ # Shape [4, 1]
+ ns = np.array([[2.], [3.], [4.], [5.]])
+
+ variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
+ covariance_entry = lambda a, b, a_sum: -a * b/ a_sum**2
+ # Shape [4, 3, 3]
+ shared_matrix = np.array(4 * [[
+ [variance_entry(alpha_v[0], alpha_0),
+ covariance_entry(alpha_v[0], alpha_v[1], alpha_0),
+ covariance_entry(alpha_v[0], alpha_v[2], alpha_0)],
+ [covariance_entry(alpha_v[1], alpha_v[0], alpha_0),
+ variance_entry(alpha_v[1], alpha_0),
+ covariance_entry(alpha_v[1], alpha_v[2], alpha_0)],
+ [covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
+ covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
+ variance_entry(alpha_v[2], alpha_0)]]])
+
+ with self.test_session():
+ # ns is shape [4, 1], and alpha is shape [4, 3].
+ dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
+ variance = dist.variance()
+ expected_variance = np.expand_dims(
+ ns * (ns + alpha_0) / (1 + alpha_0), -1) * shared_matrix
+
+ self.assertEqual((4, 3, 3), variance.get_shape())
+ self.assertAllClose(expected_variance, variance.eval())
+
+ def testVariance_multidimensional(self):
+ alpha = np.random.rand(3, 5, 4)
+ alpha2 = np.random.rand(6, 3, 3)
+ # Ensure n > 0.
+ ns = np.random.geometric(p=0.8, size=[3, 5, 1]) + 1
+ ns2 = np.random.geometric(p=0.8, size=[6, 1, 1]) + 1
+
+ with self.test_session():
+ dist = tf.contrib.distributions.DirichletMultinomial(ns, alpha)
+ dist2 = tf.contrib.distributions.DirichletMultinomial(ns2, alpha2)
+
+ variance = dist.variance()
+ variance2 = dist2.variance()
+ self.assertEqual((3, 5, 4, 4), variance.get_shape())
+ self.assertEqual((6, 3, 3, 3), variance2.get_shape())
+
def testZeroCountsResultsInPmfEqualToOne(self):
# There is only one way for zero items to be selected, and this happens with
# probability 1.
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
index c20590ce35..6982a73381 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -237,6 +237,48 @@ class DirichletMultinomial(distribution.Distribution):
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
return array_ops.expand_dims(n, -1) * mean_no_n
+ def variance(self, name='mean'):
+ """Class variances for every batch member.
+
+ The variance for each batch member is defined as the following:
+
+ ```
+ Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
+ (n + alpha_0) / (1 + alpha_0)
+ ```
+
+ where `alpha_0 = sum_j alpha_j`.
+
+ The covariance between elements in a batch is defined as:
+
+ ```
+ Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
+ (n + alpha_0) / (1 + alpha_0)
+ ```
+
+ Args:
+ name: The name for this op.
+
+ Returns:
+ A `Tensor` representing the variances for each batch member.
+ """
+ alpha = self._alpha
+ alpha_sum = self._alpha_sum
+ n = self._n
+ with ops.name_scope(self.name):
+ with ops.op_scope([alpha, alpha_sum, n], name):
+ expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1)
+ shared_factor = n * (expanded_alpha_sum + n) / (
+ expanded_alpha_sum + 1) * array_ops.ones_like(alpha)
+
+ mean_no_n = alpha / expanded_alpha_sum
+ expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1)
+ variance = -math_ops.batch_matmul(
+ expanded_mean_no_n, expanded_mean_no_n, adj_y=True)
+ variance += array_ops.batch_matrix_diag(mean_no_n)
+ variance *= array_ops.expand_dims(shared_factor, -1)
+ return variance
+
def batch_shape(self, name='batch_shape'):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.