-rw-r--r--  tensorflow/contrib/distributions/python/ops/quantized_distribution.py | 64
1 file changed, 57 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 1ef7651d03..eb94760ad7 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -128,7 +128,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`.
 class QuantizedDistribution(distributions.Distribution):
   """Distribution representing the quantization `Y = ceiling(X)`.
 
-  #### Definition in terms of sampling.
+  #### Definition in Terms of Sampling
 
   ```
   1. Draw X
@@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution):
   5. Return Y
   ```
 
-  #### Definition in terms of the probability mass function.
+  #### Definition in Terms of the Probability Mass Function
 
   Given scalar random variable `X`, we define a discrete random variable `Y`
   supported on the integers as follows:
@@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution):
   `P[Y = j]` is still the mass of `X` within the `jth` interval.
 
-  #### Caveats
+  #### Examples
+
+  We illustrate a mixture of discretized logistic distributions
+  [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
+  audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
+  a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
+  `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
+  The lowest value has probability `P(X <= 0.5)` and the highest value has
+  probability `P(2**16 - 1.5 < X)`.
+
+  Below we assume a `wavenet` function. It takes as `inputs` right-shifted audio
+  samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
+  shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
+  `scale` parameter belonging to the logistic distribution, and a `logits`
+  parameter determining the unnormalized probability of that component.
+
+  ```python
+  tfd = tf.contrib.distributions
+  tfb = tfd.bijectors
+
+  net = wavenet(inputs)
+  loc, unconstrained_scale, logits = tf.split(net,
+                                              num_or_size_splits=3,
+                                              axis=-1)
+  scale = tf.nn.softplus(unconstrained_scale)
+
+  # Form mixture of discretized logistic distributions. Note we shift the
+  # logistic distribution by -0.5. This lets the quantization capture "rounding"
+  # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
+  discretized_logistic_dist = tfd.QuantizedDistribution(
+      distribution=tfd.TransformedDistribution(
+          distribution=tfd.Logistic(loc=loc, scale=scale),
+          bijector=tfb.AffineScalar(shift=-0.5)),
+      low=0.,
+      high=2**16 - 1.)
+  mixture_dist = tfd.MixtureSameFamily(
+      mixture_distribution=tfd.Categorical(logits=logits),
+      components_distribution=discretized_logistic_dist)
+
+  neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
+  train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
+  ```
+
+  After instantiating `mixture_dist`, we illustrate maximum likelihood by
+  evaluating its log-probability on the audio samples `targets` and optimizing.
+
+  #### References
 
-  Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
-  a closed form function such as for a Poisson), computations such as mean and
-  entropy are better done with samples or approximations, and are not
-  implemented by this class.
+  [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
+       PixelCNN++: Improving the PixelCNN with discretized logistic mixture
+       likelihood and other modifications.
+       _International Conference on Learning Representations_, 2017.
+       https://arxiv.org/abs/1701.05517
+  [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+       Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+       https://arxiv.org/abs/1711.10433
   """
 
   def __init__(self,
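As a quick illustration of the quantization semantics documented above, here is a minimal sketch using the same `tf.contrib.distributions` API. The standard Normal base distribution and the cutoff values `low=-3.`, `high=3.` are assumptions chosen purely for illustration and are not part of this change: for an interior integer `j`, `P[Y = j]` is the base CDF mass of the interval `(j - 1, j]`, while the cutoffs absorb the two tails.

```python
import tensorflow as tf

tfd = tf.contrib.distributions

# Hypothetical example: quantize a standard Normal, i.e. Y = ceiling(X),
# with Y clipped to the assumed cutoffs low=-3 and high=3.
base = tfd.Normal(loc=0., scale=1.)
quantized = tfd.QuantizedDistribution(distribution=base, low=-3., high=3.)

# Interior integer: P[Y = 1] = P[0 < X <= 1] = cdf(1) - cdf(0).
interior = quantized.prob(1.)
check = base.cdf(1.) - base.cdf(0.)

# Cutoffs absorb the tails:
#   P[Y = low]  = P[X <= low]
#   P[Y = high] = P[X > high - 1]
low_mass = quantized.prob(-3.)
high_mass = quantized.prob(3.)

with tf.Session() as sess:
  print(sess.run([interior, check]))      # both ~0.3413
  print(sess.run([low_mass, high_mass]))  # ~0.0013 and ~0.0228
```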