path: root/tensorflow/contrib/bayesflow
author     Dustin Tran <trandustin@google.com>             2017-12-15 16:38:28 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-12-15 16:41:59 -0800
commit     4f4abcacedcba5430e03320f39205d2f327df2ac (patch)
tree       f8f8b4cef612436b342b1db08799adf6f41dade6 /tensorflow/contrib/bayesflow
parent     f3df9fcaefeb3ab0fd83f255bec93e1a3c013a5e (diff)
Restandardize `DenseVariational` as a simpler template for other probabilistic layers.
PiperOrigin-RevId: 179255435
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py |  55
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py          | 501
2 files changed, 194 insertions, 362 deletions
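
The diff below flattens the layer's struct-like `kernel` and `bias` containers into plain attributes and wraps the default posterior/prior distributions in `Independent`. A minimal sketch of the resulting attribute access, assuming the contrib API at this revision (`kernel_use_local_reparameterization=False` is set explicitly only so the sampled kernel tensor is populated):

```python
import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib

features = tf.random_uniform([2, 3], seed=1)
# Without the local reparameterization trick, the layer samples the kernel
# itself and exposes that sample on `kernel_posterior_tensor`.
layer = prob_layers_lib.DenseVariational(
    units=4, kernel_use_local_reparameterization=False)
outputs = layer(features)

# Before this change: layer.kernel.posterior_tensor / layer.bias.posterior_tensor
# After this change:  layer.kernel_posterior_tensor / layer.bias_posterior_tensor
kernel_sample = layer.kernel_posterior_tensor
kl_penalties = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
```
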
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
index 50358fd1c2..7b5b2fec1e 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -41,7 +42,7 @@ class Counter(object):
return self._value
-class MockDistribution(normal_lib.Normal):
+class MockDistribution(independent_lib.Independent):
"""Monitors DenseVariational calls to the underlying distribution."""
def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
@@ -49,6 +50,10 @@ class MockDistribution(normal_lib.Normal):
self.result_log_prob = result_log_prob
self.result_loc = loc
self.result_scale = scale
+ self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0)
+ if loc is not None and scale is not None:
+ self.result_distribution = normal_lib.Normal(loc=self.result_loc,
+ scale=self.result_scale)
self.called_log_prob = Counter()
self.called_sample = Counter()
self.called_loc = Counter()
@@ -63,6 +68,10 @@ class MockDistribution(normal_lib.Normal):
return self.result_sample
@property
+ def distribution(self): # for dummy check on Independent(Normal)
+ return self.result_distribution
+
+ @property
def loc(self):
self.called_loc()
return self.result_loc
@@ -95,16 +104,16 @@ class DenseVariationalLocalReparametrization(test.TestCase):
inputs = random_ops.random_uniform([2, 3], seed=1)
# No keys.
- loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(len(loss_keys), 0)
- self.assertListEqual(dense_vi.losses, loss_keys)
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(losses), 0)
+ self.assertListEqual(dense_vi.losses, losses)
_ = dense_vi(inputs)
# Yes keys.
- loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense_vi.losses, loss_keys)
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(losses), 1)
+ self.assertListEqual(dense_vi.losses, losses)
def testKLPenaltyBoth(self):
def _make_normal(dtype, *args): # pylint: disable=unused-argument
@@ -118,16 +127,16 @@ class DenseVariationalLocalReparametrization(test.TestCase):
inputs = random_ops.random_uniform([2, 3], seed=1)
# No keys.
- loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(len(loss_keys), 0)
- self.assertListEqual(dense_vi.losses, loss_keys)
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(losses), 0)
+ self.assertListEqual(dense_vi.losses, losses)
_ = dense_vi(inputs)
# Yes keys.
- loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
- self.assertEqual(len(loss_keys), 2)
- self.assertListEqual(dense_vi.losses, loss_keys)
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(losses), 2)
+ self.assertListEqual(dense_vi.losses, losses)
def testVariationalNonLocal(self):
batch_size, in_size, out_size = 2, 3, 4
@@ -183,9 +192,9 @@ class DenseVariationalLocalReparametrization(test.TestCase):
expected_bias_divergence_, actual_bias_divergence_,
] = sess.run([
expected_outputs, outputs,
- kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor,
+ kernel_posterior.result_sample, dense_vi.kernel_posterior_tensor,
kernel_divergence.result, kl_penalty[0],
- bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
+ bias_posterior.result_sample, dense_vi.bias_posterior_tensor,
bias_divergence.result, kl_penalty[1],
])
@@ -206,11 +215,15 @@ class DenseVariationalLocalReparametrization(test.TestCase):
rtol=1e-6, atol=0.)
self.assertAllEqual(
- [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]],
+ [[kernel_posterior.distribution,
+ kernel_prior.distribution,
+ kernel_posterior.result_sample]],
kernel_divergence.args)
self.assertAllEqual(
- [[bias_posterior, bias_prior, bias_posterior.result_sample]],
+ [[bias_posterior.distribution,
+ bias_prior.distribution,
+ bias_posterior.result_sample]],
bias_divergence.args)
def testVariationalLocal(self):
@@ -274,7 +287,7 @@ class DenseVariationalLocalReparametrization(test.TestCase):
] = sess.run([
expected_outputs, outputs,
kernel_divergence.result, kl_penalty[0],
- bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
+ bias_posterior.result_sample, dense_vi.bias_posterior_tensor,
bias_divergence.result, kl_penalty[1],
])
@@ -292,11 +305,13 @@ class DenseVariationalLocalReparametrization(test.TestCase):
rtol=1e-6, atol=0.)
self.assertAllEqual(
- [[kernel_posterior, kernel_prior, None]],
+ [[kernel_posterior.distribution, kernel_prior.distribution, None]],
kernel_divergence.args)
self.assertAllEqual(
- [[bias_posterior, bias_prior, bias_posterior.result_sample]],
+ [[bias_posterior.distribution,
+ bias_prior.distribution,
+ bias_posterior.result_sample]],
bias_divergence.args)
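
The test updates above reflect the new contract that posteriors and priors are `Independent`-wrapped distributions and that the divergence callable receives the underlying `.distribution`. A minimal sketch of what that wrapping does, using the same contrib import as this patch and the TF 1.x `tf.distributions.Normal`:

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import independent as independent_lib

# A [3, 4] batch of scalar Normals, one per kernel entry...
normal = tf.distributions.Normal(loc=tf.zeros([3, 4]), scale=tf.ones([3, 4]))
# ...reinterpreted as one distribution over the whole [3, 4] kernel.
multivariate = independent_lib.Independent(normal, reinterpreted_batch_ndims=2)

print(normal.batch_shape)        # (3, 4)
print(multivariate.batch_shape)  # ()
print(multivariate.event_shape)  # (3, 4)
# The layer now unwraps `Independent` before calling the divergence function:
underlying_normal = multivariate.distribution
```
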
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
index b05ce0ffc1..a3b22f334a 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
@@ -28,10 +28,12 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_lib
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
@@ -187,48 +189,34 @@ def default_mean_field_normal_fn(
loc_constraint,
untransformed_scale_constraint)
def _fn(dtype, shape, name, trainable, add_variable_fn):
- """Creates a batch of `Deterministic` or `Normal` distributions."""
+ """Creates multivariate `Deterministic` or `Normal` distribution."""
loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
if scale is None:
- return deterministic_lib.Deterministic(loc=loc)
- return normal_lib.Normal(loc=loc, scale=scale)
+ dist = deterministic_lib.Deterministic(loc=loc)
+ else:
+ dist = normal_lib.Normal(loc=loc, scale=scale)
+ reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
+ return independent_lib.Independent(
+ dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
return _fn
class DenseVariational(layers_lib.Layer):
"""Densely-connected variational class.
- This layer implements the Bayesian variational inference analogue to:
- `outputs = activation(matmul(inputs, kernel) + bias)`
- by assuming the `kernel` and/or the `bias` are random variables.
-
- The layer implements a stochastic dense calculation by making a Monte Carlo
- approximation of a [variational Bayesian method based on KL divergence](
- https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
+ This layer implements the Bayesian variational inference analogue to
+ a dense layer by assuming the `kernel` and/or the `bias` are drawn
+ from distributions. By default, the layer implements a stochastic
+ forward pass via sampling from the kernel and bias posteriors,
```none
- -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
- = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
- <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's
- = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
- ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
- + KL[q(W|x), p(W)]
+ kernel, bias ~ posterior
+ outputs = activation(matmul(inputs, kernel) + bias)
```
- where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
- is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`,
- and `~=` denotes an approximation which becomes exact as `m->inf`. The above
- bound is sometimes referred to as the negative Evidence Lower BOund or
- negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
- layer is appropriate to use when the final loss is a negative log-likelihood.
-
- The Monte-Carlo sum portion is used for the feed-forward calculation of the
- DNN. The KL divergence portion can be added to the final loss via:
- `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
-
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
- random variables (which together comprise `W`).
+ distributions.
Args:
units: Integer or Long, dimensionality of the output space.
@@ -285,10 +273,39 @@ class DenseVariational(layers_lib.Layer):
activity_regularizer: Regularizer function for the output.
kernel_use_local_reparameterization: Python `bool` indicating whether
`kernel` calculation should employ the Local Reparameterization Trick.
- kernel: `VariationalKernelParamater` instance containing all `kernel`
- related properties and `callable`s.
- bias: `VariationalParameter` instance containing all `kernel`
- related properties and `callable`s.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tfp.layers.DenseVariational(512, activation=tf.nn.relu)(features)
+ logits = tfp.layers.DenseVariational(10)(net)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses reparameterization gradients to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
"""
def __init__(
@@ -314,49 +331,19 @@ class DenseVariational(layers_lib.Layer):
name=name,
activity_regularizer=activity_regularizer,
**kwargs)
- self._units = units
- self._activation = activation
- self._input_spec = layers_lib.InputSpec(min_ndim=2)
- self._kernel_use_local_reparameterization = (
+ self.units = units
+ self.activation = activation
+ self.input_spec = layers_lib.InputSpec(min_ndim=2)
+ self.kernel_use_local_reparameterization = (
kernel_use_local_reparameterization)
- self._kernel = VariationalKernelParameter(
- kernel_posterior_fn,
- kernel_posterior_tensor_fn,
- kernel_prior_fn,
- kernel_divergence_fn)
- self._bias = VariationalParameter(
- bias_posterior_fn,
- bias_posterior_tensor_fn,
- bias_prior_fn,
- bias_divergence_fn)
-
- @property
- def units(self):
- return self._units
-
- @property
- def activation(self):
- return self._activation
-
- @property
- def input_spec(self):
- return self._input_spec
-
- @input_spec.setter
- def input_spec(self, value):
- self._input_spec = value
-
- @property
- def kernel_use_local_reparameterization(self):
- return self._kernel_use_local_reparameterization
-
- @property
- def kernel(self):
- return self._kernel
-
- @property
- def bias(self):
- return self._bias
+ self.kernel_posterior_fn = kernel_posterior_fn
+ self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
+ self.kernel_prior_fn = kernel_prior_fn
+ self.kernel_divergence_fn = kernel_divergence_fn
+ self.bias_posterior_fn = bias_posterior_fn
+ self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
+ self.bias_prior_fn = bias_prior_fn
+ self.bias_divergence_fn = bias_divergence_fn
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
@@ -368,29 +355,29 @@ class DenseVariational(layers_lib.Layer):
dtype = dtypes.as_dtype(self.dtype)
# Must have a posterior kernel.
- self.kernel.posterior = self.kernel.posterior_fn(
+ self.kernel_posterior = self.kernel_posterior_fn(
dtype, [in_size, self.units], "kernel_posterior",
self.trainable, self.add_variable)
- if self.kernel.prior_fn is None:
+ if self.kernel_prior_fn is None:
self.kernel_prior = None
else:
- self.kernel.prior = self.kernel.prior_fn(
+ self.kernel_prior = self.kernel_prior_fn(
dtype, [in_size, self.units], "kernel_prior",
self.trainable, self.add_variable)
self._built_kernel_divergence = False
- if self.bias.posterior_fn is None:
- self.bias.posterior = None
+ if self.bias_posterior_fn is None:
+ self.bias_posterior = None
else:
- self.bias.posterior = self.bias.posterior_fn(
+ self.bias_posterior = self.bias_posterior_fn(
dtype, [self.units], "bias_posterior",
self.trainable, self.add_variable)
- if self.bias.prior_fn is None:
- self.bias.prior = None
+ if self.bias_prior_fn is None:
+ self.bias_prior = None
else:
- self.bias.prior = self.bias.prior_fn(
+ self.bias_prior = self.bias_prior_fn(
dtype, [self.units], "bias_prior",
self.trainable, self.add_variable)
self._built_bias_divergence = False
@@ -405,54 +392,77 @@ class DenseVariational(layers_lib.Layer):
if self.activation is not None:
outputs = self.activation(outputs) # pylint: disable=not-callable
if not self._built_kernel_divergence:
- self._apply_divergence(self.kernel, name="divergence_kernel")
+ kernel_posterior = self.kernel_posterior
+ kernel_prior = self.kernel_prior
+ if isinstance(self.kernel_posterior, independent_lib.Independent):
+ kernel_posterior = kernel_posterior.distribution
+ if isinstance(self.kernel_prior, independent_lib.Independent):
+ kernel_prior = kernel_prior.distribution
+ self._apply_divergence(self.kernel_divergence_fn,
+ kernel_posterior,
+ kernel_prior,
+ self.kernel_posterior_tensor,
+ name="divergence_kernel")
self._built_kernel_divergence = True
if not self._built_bias_divergence:
- self._apply_divergence(self.bias, name="divergence_bias")
+ bias_posterior = self.bias_posterior
+ bias_prior = self.bias_prior
+ if isinstance(self.bias_posterior, independent_lib.Independent):
+ bias_posterior = bias_posterior.distribution
+ if isinstance(self.bias_prior, independent_lib.Independent):
+ bias_prior = bias_prior.distribution
+ self._apply_divergence(self.bias_divergence_fn,
+ bias_posterior,
+ bias_prior,
+ self.bias_posterior_tensor,
+ name="divergence_bias")
self._built_bias_divergence = True
return outputs
def _apply_variational_kernel(self, inputs):
if not self.kernel_use_local_reparameterization:
- self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn(
- self.kernel.posterior)
- self.kernel.posterior_affine = None
- self.kernel.posterior_affine_tensor = None
- return self._matmul(inputs, self.kernel.posterior_tensor)
- if not isinstance(self.kernel.posterior, normal_lib.Normal):
- raise TypeError("`kernel_use_local_reparameterization=True` requires "
- "`kernel_posterior_fn` produce an instance of "
- "`tf.distributions.Normal` (saw: \"{}\").".format(
- type(self.kernel.posterior).__name__))
- self.kernel.posterior_affine = normal_lib.Normal(
- loc=self._matmul(inputs, self.kernel.posterior.loc),
+ self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
+ self.kernel_posterior)
+ self.kernel_posterior_affine = None
+ self.kernel_posterior_affine_tensor = None
+ return self._matmul(inputs, self.kernel_posterior_tensor)
+ if (not isinstance(self.kernel_posterior, independent_lib.Independent) or
+ not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)):
+ raise TypeError(
+ "`kernel_use_local_reparameterization=True` requires "
+ "`kernel_posterior_fn` produce an instance of "
+ "`tf.distributions.Independent(tf.distributions.Normal)` "
+ "(saw: \"{}\").".format(type(self.kernel_posterior).__name__))
+ self.kernel_posterior_affine = normal_lib.Normal(
+ loc=self._matmul(inputs, self.kernel_posterior.distribution.loc),
scale=standard_ops.sqrt(self._matmul(
standard_ops.square(inputs),
- standard_ops.square(self.kernel.posterior.scale))))
- self.kernel.posterior_affine_tensor = (
- self.kernel.posterior_tensor_fn(self.kernel.posterior_affine))
- self.kernel.posterior_tensor = None
- return self.kernel.posterior_affine_tensor
+ standard_ops.square(self.kernel_posterior.distribution.scale))))
+ self.kernel_posterior_affine_tensor = (
+ self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
+ self.kernel_posterior_tensor = None
+ return self.kernel_posterior_affine_tensor
def _apply_variational_bias(self, inputs):
- if self.bias.posterior is None:
- self.bias.posterior_tensor = None
+ if self.bias_posterior is None:
+ self.bias_posterior_tensor = None
return inputs
- self.bias.posterior_tensor = self.bias.posterior_tensor_fn(
- self.bias.posterior)
- return nn.bias_add(inputs, self.bias.posterior_tensor)
-
- def _apply_divergence(self, param, name):
- if (param.divergence_fn is None or
- param.posterior is None or
- param.prior is None):
- param.divergence = None
+ self.bias_posterior_tensor = self.bias_posterior_tensor_fn(
+ self.bias_posterior)
+ return nn.bias_add(inputs, self.bias_posterior_tensor)
+
+ def _apply_divergence(self, divergence_fn, posterior, prior,
+ posterior_tensor, name):
+ if (divergence_fn is None or
+ posterior is None or
+ prior is None):
+ divergence = None
return
- param.divergence = standard_ops.identity(
- param.divergence_fn(
- param.posterior, param.prior, param.posterior_tensor),
+ divergence = standard_ops.identity(
+ divergence_fn(
+ posterior, prior, posterior_tensor),
name=name)
- self.add_loss(param.divergence)
+ self.add_loss(divergence)
def _matmul(self, inputs, kernel):
if inputs.shape.ndims <= 2:
@@ -489,37 +499,19 @@ def dense_variational(
reuse=None):
"""Densely-connected variational layer.
- This layer implements the Bayesian variational inference analogue to:
- `outputs = activation(matmul(inputs, kernel) + bias)`
- by assuming the `kernel` and/or the `bias` are random variables.
-
- The layer implements a stochastic dense calculation by making a Monte Carlo
- approximation of a [variational Bayesian method based on KL divergence](
- https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
+ This layer implements the Bayesian variational inference analogue to
+ a dense layer by assuming the `kernel` and/or the `bias` are drawn
+ from distributions. By default, the layer implements a stochastic
+ forward pass via sampling from the kernel and bias posteriors,
```none
- -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
- = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
- <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's
- = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
- ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
- + KL[q(W|x), p(W)]
+ kernel, bias ~ posterior
+ outputs = activation(matmul(inputs, kernel) + bias)
```
- where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
- is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`,
- and `~=` denotes an approximation which becomes exact as `m->inf`. The above
- bound is sometimes referred to as the negative Evidence Lower BOund or
- negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
- layer is appropriate to use when the final loss is a negative log-likelihood.
-
- The Monte-Carlo sum portion is used for the feed-forward calculation of the
- DNN. The KL divergence portion can be added to the final loss via:
- `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
-
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
- random variables (which together comprise `W`).
+ distributions.
Args:
inputs: Tensor input.
@@ -574,6 +566,31 @@ def dense_variational(
Returns:
output: `Tensor` representing the affine transformed input under a random
draw from the surrogate posterior distribution.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tfp.layers.dense_variational(features, 512, activation=tf.nn.relu)
+ logits = tfp.layers.dense_variational(net, 10)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses reparameterization gradients to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
"""
layer = DenseVariational(
units,
@@ -595,203 +612,3 @@ def dense_variational(
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
-
-
-class NotSet(object):
- """Helper to track whether a `VariationalParameter` value has been set."""
- pass
-
-
-class VariationalParameter(object):
- """Struct-like container of variational parameter properties.
-
- A `VariationalParameter` is intitialized with Python `callable`s which set the
- value of correspondingly named members. Corresponding values have "set once"
- semantics, i.e., once set to any value they are immutable.
- """
-
- def __init__(
- self,
- posterior_fn,
- posterior_tensor_fn,
- prior_fn,
- divergence_fn):
- """Creates the `VariationalParameter` struct-like object.
-
- Args:
- posterior_fn: Python `callable` which creates a
- `tf.distribution.Distribution` like object representing the posterior
- distribution. See `VariationalParameter.posterior_fn` for `callable`'s
- required parameters.
- posterior_tensor_fn: Python `callable` which computes a `Tensor`
- which represents the `posterior`.
- prior_fn: Python `callable` which creates a
- `tf.distribution.Distribution` like object representing the prior
- distribution. See `VariationalParameter.prior_fn` for `callable`'s
- required parameters.
- divergence_fn: Python `callable` which computes the KL divergence from
- `posterior` to `prior`. See `VariationalParameter.divergence_fn` for
- required `callable`'s parameters.
- """
- self._posterior_fn = posterior_fn
- self._posterior = NotSet()
- self._posterior_tensor_fn = posterior_tensor_fn
- self._posterior_tensor = NotSet()
- self._prior_fn = prior_fn
- self._prior = NotSet()
- self._divergence_fn = divergence_fn
- self._divergence = NotSet()
- self._init_helper()
-
- @property
- def posterior_fn(self):
- """`callable` which creates `tf.distributions.Distribution`-like posterior.
-
- The `callable` must accept the following parameters:
- name: Python `str` name prepended to any created (or existing)
- `tf.Variable`s.
- shape: Python `list`-like representing the parameter's event shape.
- dtype: Type of parameter's event.
- trainable: Python `bool` indicating all created `tf.Variable`s should be
- added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
- add_variable_fn: `tf.get_variable`-like `callable` used to create (or
- access existing) `tf.Variable`s.
-
- Returns:
- posterior_fn: The Python `callable` specified in `__init__`.
- """
- return self._posterior_fn
-
- @property
- def posterior(self):
- """`tf.distributions.Distribution`-like instance representing posterior."""
- return self._posterior
-
- @posterior.setter
- def posterior(self, value):
- """One-time setter of the `posterior` distribution."""
- if not isinstance(self._posterior, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._posterior = value
-
- @property
- def posterior_tensor_fn(self):
- """Creates `Tensor` representing the `posterior` distribution.
-
- The `callable` must accept the following parameters:
- posterior: `tf.distributions.Distribution`-like instance.
-
- Returns:
- posterior_tensor_fn: The Python `callable` specified in
- `__init__`.
- """
- return self._posterior_tensor_fn
-
- @property
- def posterior_tensor(self):
- """`Tensor` representing the `posterior` distribution."""
- return self._posterior_tensor
-
- @posterior_tensor.setter
- def posterior_tensor(self, value):
- """One-time setter of the `posterior_tensor`."""
- if not isinstance(self._posterior_tensor, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._posterior_tensor = value
-
- @property
- def prior_fn(self):
- """`callable` which creates `tf.distributions.Distribution`-like prior.
-
- The `callable` must accept the following parameters:
- name: Python `str` name prepended to any created (or existing)
- `tf.Variable`s.
- shape: Python `list`-like representing the parameter's event shape.
- dtype: Type of parameter's event.
- trainable: Python `bool` indicating all created `tf.Variable`s should be
- added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
- add_variable_fn: `tf.get_variable`-like `callable` used to create (or
- access existing) `tf.Variable`s.
-
- Returns:
- prior_fn: The Python `callable` specified in `__init__`.
- """
- return self._prior_fn
-
- @property
- def prior(self):
- """`tf.distributions.Distribution`-like instance representing posterior."""
- return self._prior
-
- @prior.setter
- def prior(self, value):
- """One-time setter of the `prior` distribution."""
- if not isinstance(self._prior, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._prior = value
-
- @property
- def divergence_fn(self):
- """`callable` which computes KL-divergence `Tensor` from posterior to prior.
-
- The `callable` must accept the following parameters:
- posterior: `tf.distributions.Distribution`-like instance.
- prior: `tf.distributions.Distribution`-like instance.
- posterior_tensor: `Tensor` representing value of posterior.
-
- Returns:
- divergence_fn: The Python `callable` specified in `__init__`.
- """
- return self._divergence_fn
-
- @property
- def divergence(self):
- """`Tensor` representing KL-divergence from posterior to prior."""
- return self._divergence
-
- @divergence.setter
- def divergence(self, value):
- """One-time setter of the `divergence`."""
- if not isinstance(self._divergence, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._divergence = value
-
- def _init_helper(self):
- pass
-
-
-class VariationalKernelParameter(VariationalParameter):
- """Struct-like container of variational kernel properties.
-
- A `VariationalKernelParameter` is intitialized with Python `callable`s which
- set the value of correspondingly named members. Corresponding values have "set
- once" semantics, i.e., once set to any value they are immutable.
- """
-
- @property
- def posterior_affine(self):
- """`tf.distributions.Distribution` affine transformed posterior."""
- return self._posterior_affine
-
- @posterior_affine.setter
- def posterior_affine(self, value):
- """One-time setter of `posterior_affine`."""
- if not isinstance(self._posterior_affine, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._posterior_affine = value
-
- @property
- def posterior_affine_tensor(self):
- """`Tensor` representing the `posterior_affine` distribution."""
- return self._posterior_affine_tensor
-
- @posterior_affine_tensor.setter
- def posterior_affine_tensor(self, value):
- """One-time setter of the `posterior_affine_tensor`."""
- if not isinstance(self._posterior_affine_tensor, NotSet):
- raise ValueError("Cannot override already set attribute.")
- self._posterior_affine_tensor = value
-
- def _init_helper(self):
- self._posterior_affine = NotSet()
- self._posterior_affine_tensor = NotSet()
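
For reference, the local reparameterization branch of `_apply_variational_kernel` in the diff above samples the pre-activations directly instead of sampling the kernel and multiplying. A minimal standalone sketch of that computation; the helper name `local_reparam_matmul` is illustrative and not part of this API:

```python
import tensorflow as tf

def local_reparam_matmul(inputs, kernel_loc, kernel_scale, seed=None):
  """Samples `inputs @ kernel` where kernel ~ Normal(kernel_loc, kernel_scale)."""
  # Each pre-activation is a sum of independent Gaussians, so it is itself
  # Gaussian with these moments (matching `kernel_posterior_affine` above).
  loc = tf.matmul(inputs, kernel_loc)
  scale = tf.sqrt(tf.matmul(tf.square(inputs), tf.square(kernel_scale)))
  return tf.distributions.Normal(loc=loc, scale=scale).sample(seed=seed)

inputs = tf.random_uniform([2, 3], seed=1)
kernel_loc = tf.zeros([3, 4])
kernel_scale = tf.ones([3, 4])
pre_activations = local_reparam_matmul(inputs, kernel_loc, kernel_scale)  # [2, 4]
```
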