Diffstat (limited to 'tensorflow/python/ops/distributions/dirichlet.py')
-rw-r--r--  tensorflow/python/ops/distributions/dirichlet.py | 297
1 file changed, 297 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
new file mode 100644
index 0000000000..923696a553
--- /dev/null
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -0,0 +1,297 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Dirichlet distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
+
+
+__all__ = [
+ "Dirichlet",
+]
+
+
+_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
+dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
+`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
+`self.batch_shape() + self.event_shape()`."""
+
+
+class Dirichlet(distribution.Distribution):
+ """Dirichlet distribution.
+
+ The Dirichlet distribution is defined over the
+ [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
+ length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
+ Beta distribution when `k = 2`.
+
+ #### Mathematical Details
+
+ The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,
+
+ ```none
+ S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
+ ```
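+
+  For example, for `k = 3`, `S^2` is the open triangle of points
+  `(x_0, x_1, x_2)` with positive coordinates summing to `1`.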
+
+ The probability density function (pdf) is,
+
+ ```none
+ pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
+ Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
+ ```
+
+ where:
+
+ * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
+ * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
+ * `Z` is the normalization constant aka the [multivariate beta function](
+ https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
+ and,
+ * `Gamma` is the [gamma function](
+ https://en.wikipedia.org/wiki/Gamma_function).
+
+ The `concentration` represents mean total counts of class occurrence, i.e.,
+
+ ```none
+ concentration = alpha = mean * total_concentration
+ ```
+
+ where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
+ representing a mean total count.
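+
+  For example, `concentration = [1., 2., 3.]` factors as
+  `mean = [1/6, 2/6, 3/6]` and `total_concentration = 6`.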
+
+ Distribution parameters are automatically broadcast in all functions; see
+ examples for details.
+
+ #### Examples
+
+ ```python
+ # Create a single trivariate Dirichlet, with the 3rd class being three times
+ # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
+ alpha = [1., 2, 3]
+ dist = Dirichlet(alpha)
+
+ dist.sample([4, 5]) # shape: [4, 5, 3]
+
+ # x has one sample, one batch, three classes:
+ x = [.2, .3, .5] # shape: [3]
+ dist.prob(x) # shape: []
+
+ # x has two samples from one batch:
+ x = [[.1, .4, .5],
+ [.2, .3, .5]]
+ dist.prob(x) # shape: [2]
+
+ # alpha will be broadcast to shape [5, 7, 3] to match x.
+ x = [[...]] # shape: [5, 7, 3]
+ dist.prob(x) # shape: [5, 7]
+ ```
+
+ ```python
+ # Create batch_shape=[2], event_shape=[3]:
+ alpha = [[1., 2, 3],
+ [4, 5, 6]] # shape: [2, 3]
+ dist = Dirichlet(alpha)
+
+ dist.sample([4, 5]) # shape: [4, 5, 2, 3]
+
+ x = [.2, .3, .5]
+ # x will be broadcast as [[.2, .3, .5],
+ # [.2, .3, .5]],
+  # thus matching the [2, 3] shape of alpha (batch_shape + event_shape).
+ dist.prob(x) # shape: [2]
+ ```
+
+ """
+
+ def __init__(self,
+ concentration,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Dirichlet"):
+ """Initialize a batch of Dirichlet distributions.
+
+ Args:
+ concentration: Positive floating-point `Tensor` indicating mean number
+ of class occurrences; aka "alpha". Implies `self.dtype`, and
+ `self.batch_shape`, `self.event_shape`, i.e., if
+ `concentration.shape = [N1, N2, ..., Nm, k]` then
+ `batch_shape = [N1, N2, ..., Nm]` and
+ `event_shape = [k]`.
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+ """
+ parameters = locals()
+ with ops.name_scope(name, values=[concentration]):
+ self._concentration = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration, name="concentration"),
+ validate_args)
+ self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
+ super(Dirichlet, self).__init__(
+ dtype=self._concentration.dtype,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ parameters=parameters,
+ graph_parents=[self._concentration,
+ self._total_concentration],
+ name=name)
+
+ @property
+ def concentration(self):
+ """Concentration parameter; expected counts for that coordinate."""
+ return self._concentration
+
+ @property
+ def total_concentration(self):
+ """Sum of last dim of concentration parameter."""
+ return self._total_concentration
+
+ def _batch_shape_tensor(self):
+ return array_ops.shape(self.total_concentration)
+
+ def _batch_shape(self):
+ return self.total_concentration.get_shape()
+
+ def _event_shape_tensor(self):
+ return array_ops.shape(self.concentration)[-1:]
+
+ def _event_shape(self):
+ return self.concentration.get_shape().with_rank_at_least(1)[-1:]
+
+ def _sample_n(self, n, seed=None):
+ gamma_sample = random_ops.random_gamma(
+ shape=[n],
+ alpha=self.concentration,
+ dtype=self.dtype,
+ seed=seed)
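+    # If G_j ~ Gamma(alpha_j, 1) independently, then G / sum_j G_j is
+    # Dirichlet(alpha). The NumPy analogue of the normalization below
+    # (illustrative only):
+    #   g = np.random.gamma(alpha)
+    #   x = g / g.sum(-1, keepdims=True)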
+ return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keep_dims=True)
+
+ @distribution_util.AppendDocstring(_dirichlet_sample_note)
+ def _log_prob(self, x):
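+    # log p(x; alpha) = sum_j (alpha_j - 1) * log(x_j) - lbeta(alpha). For a
+    # single x in the simplex, this should agree with SciPy (not a dependency
+    # of this module):
+    #   scipy.stats.dirichlet(alpha).logpdf(x)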
+ return self._log_unnormalized_prob(x) - self._log_normalization()
+
+ @distribution_util.AppendDocstring(_dirichlet_sample_note)
+ def _prob(self, x):
+ return math_ops.exp(self._log_prob(x))
+
+ def _log_unnormalized_prob(self, x):
+ x = self._maybe_assert_valid_sample(x)
+ return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)
+
+ def _log_normalization(self):
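+    # lbeta(alpha) = sum_j lgamma(alpha_j) - lgamma(sum_j alpha_j), the log of
+    # the normalization constant Z in the class docstring.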
+ return special_math_ops.lbeta(self.concentration)
+
+ def _entropy(self):
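+    # H = lbeta(alpha) + (alpha_0 - k) * digamma(alpha_0)
+    #     - sum_j (alpha_j - 1) * digamma(alpha_j),  alpha_0 = sum_j alpha_j.
+    # Should agree with scipy.stats.dirichlet(alpha).entropy() (SciPy is not
+    # a dependency of this module).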
+ k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
+ return (
+ self._log_normalization()
+ + ((self.total_concentration - k)
+ * math_ops.digamma(self.total_concentration))
+ - math_ops.reduce_sum(
+ (self.concentration - 1.) * math_ops.digamma(self.concentration),
+ axis=-1))
+
+ def _mean(self):
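+    # E[X_j] = alpha_j / alpha_0; `[..., newaxis]` aligns `total_concentration`
+    # with the event dimension for broadcasting.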
+ return self.concentration / self.total_concentration[..., array_ops.newaxis]
+
+ def _covariance(self):
+ x = self._variance_scale_term() * self._mean()
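+    # -x x^T already has the correct off-diagonal entries Cov[X_i, X_j]; the
+    # diagonal is then overwritten with the exact variances.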
+ return array_ops.matrix_set_diag(
+ -math_ops.matmul(x[..., array_ops.newaxis],
+ x[..., array_ops.newaxis, :]), # outer prod
+ self._variance())
+
+ def _variance(self):
+ scale = self._variance_scale_term()
+ x = scale * self._mean()
+ return x * (scale - x)
+
+ def _variance_scale_term(self):
+ """Helper to `_covariance` and `_variance` which computes a shared scale."""
+ return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])
+
+ @distribution_util.AppendDocstring(
+ """Note: The mode is undefined when any `concentration <= 1`. If
+ `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
+ `self.allow_nan_stats` is `False` an exception is raised when one or more
+ modes are undefined.""")
+ def _mode(self):
+ k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
+ mode = (self.concentration - 1.) / (
+ self.total_concentration[..., array_ops.newaxis] - k)
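+    # E.g., concentration = [2., 3., 5.] gives mode = [1., 2., 4.] / 7; the
+    # mode is undefined unless every concentration_j > 1.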
+ if self.allow_nan_stats:
+ nan = array_ops.fill(
+ array_ops.shape(mode),
+ np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
+ name="nan")
+ return array_ops.where(
+ math_ops.reduce_all(self.concentration > 1., axis=-1),
+ mode, nan)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less(
+ array_ops.ones([], self.dtype),
+ self.concentration,
+ message="Mode undefined when any concentration <= 1"),
+ ], mode)
+
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of the concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ check_ops.assert_rank_at_least(
+ concentration, 1,
+ message="Concentration parameter must have >=1 dimensions."),
+ check_ops.assert_less(
+ 1, array_ops.shape(concentration)[-1],
+ message="Concentration parameter must have event_size >= 2."),
+ ], concentration)
+
+ def _maybe_assert_valid_sample(self, x):
+ """Checks the validity of a sample."""
+ if not self.validate_args:
+ return x
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ x,
+ message="samples must be positive"),
+ distribution_util.assert_close(
+ array_ops.ones([], dtype=self.dtype),
+ math_ops.reduce_sum(x, -1),
+ message="sample last-dimension must sum to `1`"),
+ ], x)