aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar Dustin Tran <trandustin@google.com>2018-04-30 11:14:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-30 11:17:51 -0700
commitd6da4aa946e1f0763b9c3c2e6713c058eda0fdd4 (patch)
tree7de1a16d3dd5f7608b6195425f24efe81bc1351f /tensorflow/contrib/distributions
parent9f2728bf9b5439fd5a286a1088d7543600974d4a (diff)
Add snippet illustrating discretized logistic mixture for WaveNet.
Currently, the example manually centers the bins in order to capture ?rounding? intervals and not ?ceiling? intervals. In the future, we may simplify the example by expanding QuantizedDistribution with a binning argument. PiperOrigin-RevId: 194814662
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py64
1 files changed, 57 insertions, 7 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 1ef7651d03..eb94760ad7 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -128,7 +128,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`.
class QuantizedDistribution(distributions.Distribution):
"""Distribution representing the quantization `Y = ceiling(X)`.
- #### Definition in terms of sampling.
+ #### Definition in Terms of Sampling
```
1. Draw X
@@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution):
5. Return Y
```
- #### Definition in terms of the probability mass function.
+ #### Definition in Terms of the Probability Mass Function
Given scalar random variable `X`, we define a discrete random variable `Y`
supported on the integers as follows:
@@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution):
`P[Y = j]` is still the mass of `X` within the `jth` interval.
- #### Caveats
+ #### Examples
+
+ We illustrate a mixture of discretized logistic distributions
+ [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
+ audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
+ a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
+ `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
+ The lowest value has probability `P(X <= 0.5)` and the highest value has
+ probability `P(2**16 - 1.5 < X)`.
+
+ Below we assume a `wavenet` function. It takes as `input` right-shifted audio
+ samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
+ shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
+ `scale` parameter belonging to the logistic distribution, and a `logits`
+ parameter determining the unnormalized probability of that component.
+
+ ```python
+ tfd = tf.contrib.distributions
+ tfb = tfd.bijectors
+
+ net = wavenet(inputs)
+ loc, unconstrained_scale, logits = tf.split(net,
+ num_or_size_splits=3,
+ axis=-1)
+ scale = tf.nn.softplus(unconstrained_scale)
+
+ # Form mixture of discretized logistic distributions. Note we shift the
+ # logistic distribution by -0.5. This lets the quantization capture "rounding"
+ # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
+ discretized_logistic_dist = tfd.QuantizedDistribution(
+ distribution=tfd.TransformedDistribution(
+ distribution=tfd.Logistic(loc=loc, scale=scale),
+ bijector=tfb.AffineScalar(shift=-0.5)),
+ low=0.,
+ high=2**16 - 1.)
+ mixture_dist = tfd.MixtureSameFamily(
+ mixture_distribution=tfd.Categorical(logits=logits),
+ components_distribution=discretized_logistic_dist)
+
+ neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
+ train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
+ ```
+
+ After instantiating `mixture_dist`, we illustrate maximum likelihood by
+ calculating its log-probability of audio samples as `target` and optimizing.
+
+ #### References
- Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
- a closed form function such as for a Poisson), computations such as mean and
- entropy are better done with samples or approximations, and are not
- implemented by this class.
+ [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
+ PixelCNN++: Improving the PixelCNN with discretized logistic mixture
+ likelihood and other modifications.
+ _International Conference on Learning Representations_, 2017.
+ https://arxiv.org/abs/1701.05517
+ [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+ Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+ https://arxiv.org/abs/1711.10433
"""
def __init__(self,