diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/beta.py')
-rw-r--r-- | tensorflow/python/ops/distributions/beta.py | 366 |
1 files changed, 366 insertions, 0 deletions
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py new file mode 100644 index 0000000000..2b93478cdf --- /dev/null +++ b/tensorflow/python/ops/distributions/beta.py @@ -0,0 +1,366 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Beta distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import util as distribution_util + + +__all__ = [ + "Beta", + "BetaWithSoftplusConcentration", +] + + +_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in +`[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" + + +class Beta(distribution.Distribution): + """Beta distribution. + + The Beta distribution is defined over the `(0, 1)` interval using parameters + `concentration1` (aka "alpha") and `concentration0` (aka "beta"). + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z + Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta) + ``` + + where: + + * `concentration1 = alpha`, + * `concentration0 = beta`, + * `Z` is the normalization constant, and, + * `Gamma` is the [gamma function]( + https://en.wikipedia.org/wiki/Gamma_function). + + The concentration parameters represent mean total counts of a `1` or a `0`, + i.e., + + ```none + concentration1 = alpha = mean * total_concentration + concentration0 = beta = (1. - mean) * total_concentration + ``` + + where `mean` in `(0, 1)` and `total_concentration` is a positive real number + representing a mean `total_count = concentration1 + concentration0`. + + Distribution parameters are automatically broadcast in all functions; see + examples for details. + + #### Examples + + ```python + # Create a batch of three Beta distributions. + alpha = [1, 2, 3] + beta = [1, 2, 3] + dist = Beta(alpha, beta) + + dist.sample([4, 5]) # Shape [4, 5, 3] + + # `x` has three batch entries, each with two samples. + x = [[.1, .4, .5], + [.2, .3, .5]] + # Calculate the probability of each pair of samples under the corresponding + # distribution in `dist`. + dist.prob(x) # Shape [2, 3] + ``` + + ```python + # Create batch_shape=[2, 3] via parameter broadcast: + alpha = [[1.], [2]] # Shape [2, 1] + beta = [3., 4, 5] # Shape [3] + dist = Beta(alpha, beta) + + # alpha broadcast as: [[1., 1, 1,], + # [2, 2, 2]] + # beta broadcast as: [[3., 4, 5], + # [3, 4, 5]] + # batch_Shape [2, 3] + dist.sample([4, 5]) # Shape [4, 5, 2, 3] + + x = [.2, .3, .5] + # x will be broadcast as [[.2, .3, .5], + # [.2, .3, .5]], + # thus matching batch_shape [2, 3]. + dist.prob(x) # Shape [2, 3] + ``` + + """ + + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + allow_nan_stats=True, + name="Beta"): + """Initialize a batch of Beta distributions. + + Args: + concentration1: Positive floating-point `Tensor` indicating mean + number of successes; aka "alpha". Implies `self.dtype` and + `self.batch_shape`, i.e., + `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. + concentration0: Positive floating-point `Tensor` indicating mean + number of failures; aka "beta". Otherwise has same semantics as + `concentration1`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + with ops.name_scope(name, values=[concentration1, concentration0]): + self._concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args) + self._concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args) + check_ops.assert_same_float_dtype([ + self._concentration1, self._concentration0]) + self._total_concentration = self._concentration1 + self._concentration0 + super(Beta, self).__init__( + dtype=self._total_concentration.dtype, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + reparameterization_type=distribution.NOT_REPARAMETERIZED, + parameters=parameters, + graph_parents=[self._concentration1, + self._concentration0, + self._total_concentration], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return dict(zip( + ["concentration1", "concentration0"], + [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)) + + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self._concentration1 + + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self._concentration0 + + @property + def total_concentration(self): + """Sum of concentration parameters.""" + return self._total_concentration + + def _batch_shape_tensor(self): + return array_ops.shape(self.total_concentration) + + def _batch_shape(self): + return self.total_concentration.get_shape() + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + gamma1_sample = random_ops.random_gamma( + shape=[n], + alpha=expanded_concentration1, + dtype=self.dtype, + seed=seed) + gamma2_sample = random_ops.random_gamma( + shape=[n], + alpha=expanded_concentration0, + dtype=self.dtype, + seed=distribution_util.gen_new_seed(seed, "beta")) + beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample) + return beta_sample + + @distribution_util.AppendDocstring(_beta_sample_note) + def _log_prob(self, x): + return self._log_unnormalized_prob(x) - self._log_normalization() + + @distribution_util.AppendDocstring(_beta_sample_note) + def _prob(self, x): + return math_ops.exp(self._log_prob(x)) + + @distribution_util.AppendDocstring(_beta_sample_note) + def _log_cdf(self, x): + return math_ops.log(self._cdf(x)) + + @distribution_util.AppendDocstring(_beta_sample_note) + def _cdf(self, x): + return math_ops.betainc(self.concentration1, self.concentration0, x) + + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + return ((self.concentration1 - 1.) * math_ops.log(x) + + (self.concentration0 - 1.) * math_ops.log1p(-x)) + + def _log_normalization(self): + return (math_ops.lgamma(self.concentration1) + + math_ops.lgamma(self.concentration0) + - math_ops.lgamma(self.total_concentration)) + + def _entropy(self): + return ( + self._log_normalization() + - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1) + - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0) + + ((self.total_concentration - 2.) * + math_ops.digamma(self.total_concentration))) + + def _mean(self): + return self._concentration1 / self._total_concentration + + def _variance(self): + return self._mean() * (1. - self._mean()) / (1. + self.total_concentration) + + @distribution_util.AppendDocstring( + """Note: The mode is undefined when `concentration1 <= 1` or + `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` + is used for undefined modes. If `self.allow_nan_stats` is `False` an + exception is raised when one or more modes are undefined.""") + def _mode(self): + mode = (self.concentration1 - 1.) / (self.total_concentration - 2.) + if self.allow_nan_stats: + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), + name="nan") + is_defined = math_ops.logical_and(self.concentration1 > 1., + self.concentration0 > 1.) + return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration1, + message="Mode undefined for concentration1 <= 1."), + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration0, + message="Mode undefined for concentration0 <= 1.") + ], mode) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid_sample(self, x): + """Checks the validity of a sample.""" + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + x, + message="sample must be positive"), + check_ops.assert_less( + x, array_ops.ones([], self.dtype), + message="sample must be no larger than `1`."), + ], x) + + +class BetaWithSoftplusConcentration(Beta): + """Beta with softplus transform of `concentration1` and `concentration0`.""" + + def __init__(self, + concentration1, + concentration0, + validate_args=False, + allow_nan_stats=True, + name="BetaWithSoftplusConcentration"): + parameters = locals() + with ops.name_scope(name, values=[concentration1, + concentration0]) as ns: + super(BetaWithSoftplusConcentration, self).__init__( + concentration1=nn.softplus(concentration1, + name="softplus_concentration1"), + concentration0=nn.softplus(concentration0, + name="softplus_concentration0"), + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + self._parameters = parameters + + +@kullback_leibler.RegisterKL(Beta, Beta) +def _kl_beta_beta(d1, d2, name=None): + """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta. + + Args: + d1: instance of a Beta distribution object. + d2: instance of a Beta distribution object. + name: (optional) Name to use for created operations. + default is "kl_beta_beta". + + Returns: + Batchwise KL(d1 || d2) + """ + def delta(fn, is_property=True): + fn1 = getattr(d1, fn) + fn2 = getattr(d2, fn) + return (fn2 - fn1) if is_property else (fn2() - fn1()) + with ops.name_scope(name, "kl_beta_beta", values=[ + d1.concentration1, + d1.concentration0, + d1.total_concentration, + d2.concentration1, + d2.concentration0, + d2.total_concentration, + ]): + return (delta("_log_normalization", is_property=False) + - math_ops.digamma(d1.concentration1) * delta("concentration1") + - math_ops.digamma(d1.concentration0) * delta("concentration0") + + (math_ops.digamma(d1.total_concentration) + * delta("total_concentration"))) |