diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/transformed_distribution.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/transformed_distribution.py | 76 |
1 files changed, 51 insertions, 25 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index 1403adbda2..844f78ca96 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -19,11 +19,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops import distribution as distribution_lib +from tensorflow.contrib.distributions.python.ops import distribution as distributions from tensorflow.contrib.distributions.python.ops import distribution_util -# Bijectors must be directly imported because `remove_undocumented` prevents -# individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity +from tensorflow.contrib.distributions.python.ops.bijectors import identity as identity_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -121,7 +119,7 @@ def _is_scalar_from_shape(shape): return _logical_equal(_ndims_from_shape(shape), 0) -class TransformedDistribution(distribution_lib.Distribution): +class TransformedDistribution(distributions.Distribution): """A Transformed Distribution. A `TransformedDistribution` models `p(y)` given a base distribution `p(x)`, @@ -148,19 +146,49 @@ class TransformedDistribution(distribution_lib.Distribution): A `TransformedDistribution` implements the following operations: - * `sample` - Mathematically: `Y = g(X)` - Programmatically: `bijector.forward(distribution.sample(...))` + * `sample`: - * `log_prob` - Mathematically: `(log o pdf)(Y=y) = (log o pdf o g^{-1})(y) - + (log o abs o det o J o g^{-1})(y)` - Programmatically: `(distribution.log_prob(bijector.inverse(y)) - + bijector.inverse_log_det_jacobian(y))` + Mathematically: - * `log_cdf` - Mathematically: `(log o cdf)(Y=y) = (log o cdf o g^{-1})(y)` - Programmatically: `distribution.log_cdf(bijector.inverse(x))` + ```none + Y = g(X) + ``` + + Programmatically: + + ```python + return bijector.forward(distribution.sample(...)) + ``` + + * `log_prob`: + + Mathematically: + + ```none + (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + + (log o abs o det o J o g^{-1})(y) + ``` + + Programmatically: + + ```python + return (distribution.log_prob(bijector.inverse(y)) + + bijector.inverse_log_det_jacobian(y)) + ``` + + * `log_cdf`: + + Mathematically: + + ```none + (log o cdf)(Y=y) = (log o cdf o g^{-1})(y) + ``` + + Programmatically: + + ```python + return distribution.log_cdf(bijector.inverse(x)) + ``` * and similarly for: `cdf`, `prob`, `log_survival_function`, `survival_function`. @@ -171,7 +199,7 @@ class TransformedDistribution(distribution_lib.Distribution): ```python ds = tf.contrib.distributions log_normal = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), + distribution=ds.Normal(loc=mu, scale=sigma), bijector=ds.bijectors.Exp(), name="LogNormalTransformedDistribution") ``` @@ -181,7 +209,7 @@ class TransformedDistribution(distribution_lib.Distribution): ```python ds = tf.contrib.distributions log_normal = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), + distribution=ds.Normal(loc=mu, scale=sigma), bijector=ds.bijectors.Inline( forward_fn=tf.exp, inverse_fn=tf.log, @@ -195,11 +223,8 @@ class TransformedDistribution(distribution_lib.Distribution): ```python ds = tf.contrib.distributions normal = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=ds.bijectors.Affine( - shift=-1., - scale_identity_multiplier=2., - event_ndims=0), + distribution=ds.Normal(loc=0, scale=1), + bijector=ds.bijectors.ScaleAndShift(loc=mu, scale=sigma, event_ndims=0), name="NormalTransformedDistribution") ``` @@ -212,6 +237,7 @@ class TransformedDistribution(distribution_lib.Distribution): multivariate Normal as a `TransformedDistribution`. ```python + bs = tf.contrib.distributions.bijector ds = tf.contrib.distributions # We will create two MVNs with batch_shape = event_shape = 2. mean = [[-1., 0], # batch:0 @@ -222,7 +248,7 @@ class TransformedDistribution(distribution_lib.Distribution): [2, 2]]] # batch:1 mvn1 = ds.TransformedDistribution( distribution=ds.Normal(loc=0., scale=1.), - bijector=ds.bijectors.Affine(shift=mean, scale_tril=chol_cov), + bijector=bs.Affine(shift=mean, tril=chol_cov), batch_shape=[2], # Valid because base_distribution.batch_shape == []. event_shape=[2]) # Valid because base_distribution.event_shape == []. mvn2 = ds.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov) @@ -265,7 +291,7 @@ class TransformedDistribution(distribution_lib.Distribution): self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty") if bijector is None: - bijector = Identity(validate_args=validate_args) + bijector = identity_lib.Identity(validate_args=validate_args) # We will keep track of a static and dynamic version of # self._is_{batch,event}_override. This way we can do more prior to graph |