diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/gamma.py')
-rw-r--r-- | tensorflow/python/ops/distributions/gamma.py | 305 |
1 files changed, 305 insertions, 0 deletions
# Provenance (from the diff header this module was extracted from):
#   diff --git a/tensorflow/python/ops/distributions/gamma.py
#              b/tensorflow/python/ops/distributions/gamma.py
#   new file mode 100644
#   index 0000000000..4ac2b9b4ef
#   --- /dev/null
#   +++ b/tensorflow/python/ops/distributions/gamma.py
#   @@ -0,0 +1,305 @@
#
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Gamma distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util


__all__ = [
    "Gamma",
    "GammaWithSoftplusConcentrationRate",
]


class Gamma(distribution.Distribution):
  """Gamma distribution.

  The Gamma distribution is defined over positive real numbers using
  parameters `concentration` (aka "alpha") and `rate` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
  Z = Gamma(alpha) beta**(-alpha)
  ```

  where:

  * `concentration = alpha`, `alpha > 0`,
  * `rate = beta`, `beta > 0`,
  * `Z` is the normalizing constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The cumulative distribution function (cdf) is,

  ```none
  cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
  ```

  where `GammaInc` is the [lower incomplete Gamma function](
  https://en.wikipedia.org/wiki/Incomplete_gamma_function).

  The parameters can be intuited via their relationship to mean and stddev,

  ```none
  concentration = alpha = (mean / stddev)**2
  rate = beta = mean / stddev**2 = concentration / mean
  ```

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  WARNING: This distribution may draw 0-valued samples for small `concentration`
  values. See note in `tf.random_gamma` docstring.

  #### Examples

  ```python
  dist = Gamma(concentration=3.0, rate=2.0)
  dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
  ```

  """

  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Gamma"):
    """Construct Gamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    # Must stay the first statement: `locals()` is used to snapshot exactly the
    # constructor arguments, before any other local name is bound.
    parameters = locals()
    with ops.name_scope(name, values=[concentration, rate]):
      # Positivity assertions are only attached when validation is requested.
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        check_ops.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(Gamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both parameters get the same shape tensor: `sample_shape` as int32.
    return dict(
        zip(("concentration", "rate"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def concentration(self):
    """Concentration parameter."""
    return self._concentration

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  def _batch_shape_tensor(self):
    # Dynamic batch shape: broadcast of the two parameter shapes.
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.concentration),
        array_ops.shape(self.rate))

  def _batch_shape(self):
    # Static batch shape: broadcast of the two parameters' static shapes.
    return array_ops.broadcast_static_shape(
        self.concentration.get_shape(),
        self.rate.get_shape())

  def _event_shape_tensor(self):
    # Events are scalars, so the event shape is the empty int32 vector.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  # NOTE(review): the private `_foo` hooks below are documented with `#`
  # comments rather than docstrings on purpose -- presumably the Distribution
  # base class folds private-method docstrings into the public API docs (as
  # `AppendDocstring` does explicitly); confirm before converting to docstrings.
  @distribution_util.AppendDocstring(
      """Note: See `tf.random_gamma` docstring for sampling details and
      caveats.""")
  def _sample_n(self, n, seed=None):
    return random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        beta=self.rate,
        dtype=self.dtype,
        seed=seed)

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  def _cdf(self, x):
    x = self._maybe_assert_valid_sample(x)
    # Note that igamma returns the regularized incomplete gamma function,
    # which is what we want for the CDF.
    return math_ops.igamma(self.concentration, self.rate * x)

  def _log_unnormalized_prob(self, x):
    # log(x**(alpha - 1) * exp(-beta * x))
    x = self._maybe_assert_valid_sample(x)
    return (self.concentration - 1.) * math_ops.log(x) - self.rate * x

  def _log_normalization(self):
    # log Z = log(Gamma(alpha) * beta**(-alpha))
    return (math_ops.lgamma(self.concentration)
            - self.concentration * math_ops.log(self.rate))

  def _entropy(self):
    # alpha - log(beta) + lgamma(alpha) + (1 - alpha) * digamma(alpha)
    return (self.concentration
            - math_ops.log(self.rate)
            + math_ops.lgamma(self.concentration)
            + ((1. - self.concentration) *
               math_ops.digamma(self.concentration)))

  def _mean(self):
    return self.concentration / self.rate

  def _variance(self):
    return self.concentration / math_ops.square(self.rate)

  def _stddev(self):
    return math_ops.sqrt(self.concentration) / self.rate

  @distribution_util.AppendDocstring(
      """The mode of a gamma distribution is `(concentration - 1) / rate` when
      `concentration > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is
      `False`, an exception will be raised rather than returning `NaN`.""")
  def _mode(self):
    mode = (self.concentration - 1.) / self.rate
    if self.allow_nan_stats:
      # `as_numpy_dtype` is a property that already yields the NumPy type;
      # pass the type itself rather than an instance created by calling it.
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype),
          name="nan")
      # Batch members with concentration <= 1 are reported as NaN.
      return array_ops.where(self.concentration > 1., mode, nan)
    else:
      # Without NaN stats, require 1 < concentration for every batch member.
      return control_flow_ops.with_dependencies([
          check_ops.assert_less(
              array_ops.ones([], self.dtype),
              self.concentration,
              message="mode not defined when any concentration <= 1"),
      ], mode)

  def _maybe_assert_valid_sample(self, x):
    # The dtype check always runs; the positivity assertion is only attached
    # when `validate_args` is set.
    check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x),
    ], x)


class GammaWithSoftplusConcentrationRate(Gamma):
  """`Gamma` with softplus of `concentration` and `rate`."""

  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="GammaWithSoftplusConcentrationRate"):
    # Must stay the first statement: snapshot of the raw (pre-softplus)
    # constructor arguments.
    parameters = locals()
    with ops.name_scope(name, values=[concentration, rate]):
      super(GammaWithSoftplusConcentrationRate, self).__init__(
          concentration=nn.softplus(concentration,
                                    name="softplus_concentration"),
          rate=nn.softplus(rate, name="softplus_rate"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    # Overwrite what the parent constructor recorded so that `parameters`
    # holds the arguments this subclass was called with.
    self._parameters = parameters


@kullback_leibler.RegisterKL(Gamma, Gamma)
def _kl_gamma_gamma(g0, g1, name=None):
  """Calculate the batched KL divergence KL(g0 || g1) with g0 and g1 Gamma.

  Args:
    g0: instance of a Gamma distribution object.
    g1: instance of a Gamma distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_gamma_gamma".

  Returns:
    kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
  """
  with ops.name_scope(name, "kl_gamma_gamma", values=[
      g0.concentration, g0.rate, g1.concentration, g1.rate]):
    # Result from:
    #   http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
    # For derivation see:
    #   http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions   pylint: disable=line-too-long
    return (((g0.concentration - g1.concentration)
             * math_ops.digamma(g0.concentration))
            + math_ops.lgamma(g1.concentration)
            - math_ops.lgamma(g0.concentration)
            + g1.concentration * math_ops.log(g0.rate)
            - g1.concentration * math_ops.log(g1.rate)
            + g0.concentration * (g1.rate / g0.rate - 1.))