author    | Dustin Tran <trandustin@google.com>             | 2017-12-15 16:38:28 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 16:41:59 -0800
commit    | 4f4abcacedcba5430e03320f39205d2f327df2ac (patch)
tree      | f8f8b4cef612436b342b1db08799adf6f41dade6 /tensorflow/contrib/bayesflow
parent    | f3df9fcaefeb3ab0fd83f255bec93e1a3c013a5e (diff)
Restandardize `DenseVariational` as a simpler template for other probabilistic layers.
PiperOrigin-RevId: 179255435
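For context, the restandardized layer is meant to be used as in the example added to its docstring by this change. The sketch below restates that example (TF 1.x contrib API; `features` and `labels` are assumed to come from an input pipeline, and a `tf.reduce_mean` is added here so the likelihood term is a scalar):

```python
import tensorflow as tf

tfp = tf.contrib.bayesflow

# Each DenseVariational layer samples its kernel/bias from a learned posterior
# and registers its KL penalty under tf.GraphKeys.REGULARIZATION_LOSSES.
net = tfp.layers.DenseVariational(512, activation=tf.nn.relu)(features)
logits = tfp.layers.DenseVariational(10)(net)

# Negative ELBO = expected negative log-likelihood (one Monte Carlo sample
# per forward pass) + the KL terms collected from the layers.
neg_log_likelihood = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
loss = neg_log_likelihood + kl
train_op = tf.train.AdamOptimizer().minimize(loss)
```

The KL terms land in the `REGULARIZATION_LOSSES` collection exactly as the updated tests below assert.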
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py |  55
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py          | 501
2 files changed, 194 insertions, 362 deletions
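The core representational change in the diff below is that `default_mean_field_normal_fn` now wraps its `Normal` (or `Deterministic`) in an `Independent` distribution, so a whole kernel or bias is treated as a single event; the divergence bookkeeping and the local-reparameterization check then unwrap `.distribution` to reach the per-entry `Normal`. A minimal sketch of such a posterior object (illustrative only; the fixed `loc`/`scale` tensors stand in for the layer's learned variables, and `reinterpreted_batch_ndims` is hard-coded here while the diff computes it dynamically):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Stand-ins for the learned posterior parameters of a [3, 4] kernel.
loc = tf.zeros([3, 4])
scale = tf.ones([3, 4])

kernel_posterior = ds.Independent(
    ds.Normal(loc=loc, scale=scale),
    reinterpreted_batch_ndims=2)  # reinterpret both batch dims as the event

print(kernel_posterior.event_shape)   # (3, 4): one event for the whole kernel
print(kernel_posterior.distribution)  # the underlying per-entry Normal
```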
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
index 50358fd1c2..7b5b2fec1e 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -41,7 +42,7 @@ class Counter(object):
     return self._value


-class MockDistribution(normal_lib.Normal):
+class MockDistribution(independent_lib.Independent):
   """Monitors DenseVariational calls to the underlying distribution."""

   def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
@@ -49,6 +50,10 @@ class MockDistribution(normal_lib.Normal):
     self.result_log_prob = result_log_prob
     self.result_loc = loc
     self.result_scale = scale
+    self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0)
+    if loc is not None and scale is not None:
+      self.result_distribution = normal_lib.Normal(loc=self.result_loc,
+                                                   scale=self.result_scale)
     self.called_log_prob = Counter()
     self.called_sample = Counter()
     self.called_loc = Counter()
@@ -63,6 +68,10 @@ class MockDistribution(normal_lib.Normal):
     return self.result_sample

   @property
+  def distribution(self):  # for dummy check on Independent(Normal)
+    return self.result_distribution
+
+  @property
   def loc(self):
     self.called_loc()
     return self.result_loc
@@ -95,16 +104,16 @@ class DenseVariationalLocalReparametrization(test.TestCase):
     inputs = random_ops.random_uniform([2, 3], seed=1)

     # No keys.
-    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-    self.assertEqual(len(loss_keys), 0)
-    self.assertListEqual(dense_vi.losses, loss_keys)
+    losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+    self.assertEqual(len(losses), 0)
+    self.assertListEqual(dense_vi.losses, losses)

     _ = dense_vi(inputs)

     # Yes keys.
-    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-    self.assertEqual(len(loss_keys), 1)
-    self.assertListEqual(dense_vi.losses, loss_keys)
+    losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+    self.assertEqual(len(losses), 1)
+    self.assertListEqual(dense_vi.losses, losses)

   def testKLPenaltyBoth(self):
     def _make_normal(dtype, *args):  # pylint: disable=unused-argument
@@ -118,16 +127,16 @@ class DenseVariationalLocalReparametrization(test.TestCase):
     inputs = random_ops.random_uniform([2, 3], seed=1)

     # No keys.
-    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-    self.assertEqual(len(loss_keys), 0)
-    self.assertListEqual(dense_vi.losses, loss_keys)
+    losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+    self.assertEqual(len(losses), 0)
+    self.assertListEqual(dense_vi.losses, losses)

     _ = dense_vi(inputs)

     # Yes keys.
-    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-    self.assertEqual(len(loss_keys), 2)
-    self.assertListEqual(dense_vi.losses, loss_keys)
+    losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+    self.assertEqual(len(losses), 2)
+    self.assertListEqual(dense_vi.losses, losses)

   def testVariationalNonLocal(self):
     batch_size, in_size, out_size = 2, 3, 4
@@ -183,9 +192,9 @@ class DenseVariationalLocalReparametrization(test.TestCase):
           expected_bias_divergence_, actual_bias_divergence_,
       ] = sess.run([
           expected_outputs, outputs,
-          kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor,
+          kernel_posterior.result_sample, dense_vi.kernel_posterior_tensor,
           kernel_divergence.result, kl_penalty[0],
-          bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
+          bias_posterior.result_sample, dense_vi.bias_posterior_tensor,
           bias_divergence.result, kl_penalty[1],
       ])

@@ -206,11 +215,15 @@ class DenseVariationalLocalReparametrization(test.TestCase):
          rtol=1e-6, atol=0.)

      self.assertAllEqual(
-          [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]],
+          [[kernel_posterior.distribution,
+            kernel_prior.distribution,
+            kernel_posterior.result_sample]],
          kernel_divergence.args)

      self.assertAllEqual(
-          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
+          [[bias_posterior.distribution,
+            bias_prior.distribution,
+            bias_posterior.result_sample]],
          bias_divergence.args)

  def testVariationalLocal(self):
@@ -274,7 +287,7 @@ class DenseVariationalLocalReparametrization(test.TestCase):
      ] = sess.run([
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
-          bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
+          bias_posterior.result_sample, dense_vi.bias_posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

@@ -292,11 +305,13 @@ class DenseVariationalLocalReparametrization(test.TestCase):
          rtol=1e-6, atol=0.)

      self.assertAllEqual(
-          [[kernel_posterior, kernel_prior, None]],
+          [[kernel_posterior.distribution, kernel_prior.distribution, None]],
          kernel_divergence.args)

      self.assertAllEqual(
-          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
+          [[bias_posterior.distribution,
+            bias_prior.distribution,
+            bias_posterior.result_sample]],
          bias_divergence.args)
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
index b05ce0ffc1..a3b22f334a 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
@@ -28,10 +28,12 @@ from __future__ import print_function
 import numpy as np

 from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import base as layers_lib
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
@@ -187,48 +189,34 @@ def default_mean_field_normal_fn(
       loc_constraint, untransformed_scale_constraint)
   def _fn(dtype, shape, name, trainable, add_variable_fn):
-    """Creates a batch of `Deterministic` or `Normal` distributions."""
+    """Creates multivariate `Deterministic` or `Normal` distribution."""
     loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
     if scale is None:
-      return deterministic_lib.Deterministic(loc=loc)
-    return normal_lib.Normal(loc=loc, scale=scale)
+      dist = deterministic_lib.Deterministic(loc=loc)
+    else:
+      dist = normal_lib.Normal(loc=loc, scale=scale)
+    reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
+    return independent_lib.Independent(
+        dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
   return _fn


 class DenseVariational(layers_lib.Layer):
   """Densely-connected variational class.

-  This layer implements the Bayesian variational inference analogue to:
-  `outputs = activation(matmul(inputs, kernel) + bias)`
-  by assuming the `kernel` and/or the `bias` are random variables.
-
-  The layer implements a stochastic dense calculation by making a Monte Carlo
-  approximation of a [variational Bayesian method based on KL divergence](
-  https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
+  This layer implements the Bayesian variational inference analogue to
+  a dense layer by assuming the `kernel` and/or the `bias` are drawn
+  from distributions. By default, the layer implements a stochastic
+  forward pass via sampling from the kernel and bias posteriors,

   ```none
-  -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
-              = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
-             <= E_q(W|x)[-log p(y,W|x) + log q(W|x)]        # Jensen's
-              = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
-             ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
-                 + KL[q(W|x), p(W)]
+  kernel, bias ~ posterior
+  outputs = activation(matmul(inputs, kernel) + bias)
   ```

-  where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
-  is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`,
-  and `~=` denotes an approximation which becomes exact as `m->inf`. The above
-  bound is sometimes referred to as the negative Evidence Lower BOund or
-  negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
-  layer is appropriate to use when the final loss is a negative log-likelihood.
-
-  The Monte-Carlo sum portion is used for the feed-forward calculation of the
-  DNN. The KL divergence portion can be added to the final loss via:
-  `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
-
   The arguments permit separate specification of the surrogate posterior
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
-  random variables (which together comprise `W`).
+  distributions.

   Args:
     units: Integer or Long, dimensionality of the output space.
@@ -285,10 +273,39 @@ class DenseVariational(layers_lib.Layer):
     activity_regularizer: Regularizer function for the output.
     kernel_use_local_reparameterization: Python `bool` indicating whether
       `kernel` calculation should employ the Local Reparameterization Trick.
-    kernel: `VariationalKernelParamater` instance containing all `kernel`
-      related properties and `callable`s.
-    bias: `VariationalParameter` instance containing all `kernel`
-      related properties and `callable`s.
+    kernel_posterior_fn: `callable` returning posterior.
+    kernel_posterior_tensor_fn: `callable` operating on posterior.
+    kernel_prior_fn: `callable` returning prior.
+    kernel_divergence_fn: `callable` returning divergence.
+    bias_posterior_fn: `callable` returning posterior.
+    bias_posterior_tensor_fn: `callable` operating on posterior.
+    bias_prior_fn: `callable` returning prior.
+    bias_divergence_fn: `callable` returning divergence.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tfp.layers.DenseVariational(512, activation=tf.nn.relu)(features)
+  logits = tfp.layers.DenseVariational(10)(net)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
""" def __init__( @@ -314,49 +331,19 @@ class DenseVariational(layers_lib.Layer): name=name, activity_regularizer=activity_regularizer, **kwargs) - self._units = units - self._activation = activation - self._input_spec = layers_lib.InputSpec(min_ndim=2) - self._kernel_use_local_reparameterization = ( + self.units = units + self.activation = activation + self.input_spec = layers_lib.InputSpec(min_ndim=2) + self.kernel_use_local_reparameterization = ( kernel_use_local_reparameterization) - self._kernel = VariationalKernelParameter( - kernel_posterior_fn, - kernel_posterior_tensor_fn, - kernel_prior_fn, - kernel_divergence_fn) - self._bias = VariationalParameter( - bias_posterior_fn, - bias_posterior_tensor_fn, - bias_prior_fn, - bias_divergence_fn) - - @property - def units(self): - return self._units - - @property - def activation(self): - return self._activation - - @property - def input_spec(self): - return self._input_spec - - @input_spec.setter - def input_spec(self, value): - self._input_spec = value - - @property - def kernel_use_local_reparameterization(self): - return self._kernel_use_local_reparameterization - - @property - def kernel(self): - return self._kernel - - @property - def bias(self): - return self._bias + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -368,29 +355,29 @@ class DenseVariational(layers_lib.Layer): dtype = dtypes.as_dtype(self.dtype) # Must have a posterior kernel. 
-    self.kernel.posterior = self.kernel.posterior_fn(
+    self.kernel_posterior = self.kernel_posterior_fn(
         dtype, [in_size, self.units], "kernel_posterior",
         self.trainable, self.add_variable)

-    if self.kernel.prior_fn is None:
+    if self.kernel_prior_fn is None:
       self.kernel_prior = None
     else:
-      self.kernel.prior = self.kernel.prior_fn(
+      self.kernel_prior = self.kernel_prior_fn(
           dtype, [in_size, self.units], "kernel_prior",
           self.trainable, self.add_variable)
     self._built_kernel_divergence = False

-    if self.bias.posterior_fn is None:
-      self.bias.posterior = None
+    if self.bias_posterior_fn is None:
+      self.bias_posterior = None
     else:
-      self.bias.posterior = self.bias.posterior_fn(
+      self.bias_posterior = self.bias_posterior_fn(
          dtype, [self.units], "bias_posterior",
          self.trainable, self.add_variable)

-    if self.bias.prior_fn is None:
-      self.bias.prior = None
+    if self.bias_prior_fn is None:
+      self.bias_prior = None
     else:
-      self.bias.prior = self.bias.prior_fn(
+      self.bias_prior = self.bias_prior_fn(
          dtype, [self.units], "bias_prior",
          self.trainable, self.add_variable)
     self._built_bias_divergence = False
@@ -405,54 +392,77 @@ class DenseVariational(layers_lib.Layer):
     if self.activation is not None:
       outputs = self.activation(outputs)  # pylint: disable=not-callable
     if not self._built_kernel_divergence:
-      self._apply_divergence(self.kernel, name="divergence_kernel")
+      kernel_posterior = self.kernel_posterior
+      kernel_prior = self.kernel_prior
+      if isinstance(self.kernel_posterior, independent_lib.Independent):
+        kernel_posterior = kernel_posterior.distribution
+      if isinstance(self.kernel_prior, independent_lib.Independent):
+        kernel_prior = kernel_prior.distribution
+      self._apply_divergence(self.kernel_divergence_fn,
+                             kernel_posterior,
+                             kernel_prior,
+                             self.kernel_posterior_tensor,
+                             name="divergence_kernel")
       self._built_kernel_divergence = True
     if not self._built_bias_divergence:
-      self._apply_divergence(self.bias, name="divergence_bias")
+      bias_posterior = self.bias_posterior
+      bias_prior = self.bias_prior
+      if isinstance(self.bias_posterior, independent_lib.Independent):
+        bias_posterior = bias_posterior.distribution
+      if isinstance(self.bias_prior, independent_lib.Independent):
+        bias_prior = bias_prior.distribution
+      self._apply_divergence(self.bias_divergence_fn,
+                             bias_posterior,
+                             bias_prior,
+                             self.bias_posterior_tensor,
+                             name="divergence_bias")
       self._built_bias_divergence = True
     return outputs

   def _apply_variational_kernel(self, inputs):
     if not self.kernel_use_local_reparameterization:
-      self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn(
-          self.kernel.posterior)
-      self.kernel.posterior_affine = None
-      self.kernel.posterior_affine_tensor = None
-      return self._matmul(inputs, self.kernel.posterior_tensor)
-    if not isinstance(self.kernel.posterior, normal_lib.Normal):
-      raise TypeError("`kernel_use_local_reparameterization=True` requires "
-                      "`kernel_posterior_fn` produce an instance of "
-                      "`tf.distributions.Normal` (saw: \"{}\").".format(
-                          type(self.kernel.posterior).__name__))
-    self.kernel.posterior_affine = normal_lib.Normal(
-        loc=self._matmul(inputs, self.kernel.posterior.loc),
+      self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
+          self.kernel_posterior)
+      self.kernel_posterior_affine = None
+      self.kernel_posterior_affine_tensor = None
+      return self._matmul(inputs, self.kernel_posterior_tensor)
+    if (not isinstance(self.kernel_posterior, independent_lib.Independent) or
+        not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)):
+      raise TypeError(
+          "`kernel_use_local_reparameterization=True` requires "
+          "`kernel_posterior_fn` produce an instance of "
+          "`tf.distributions.Independent(tf.distributions.Normal)` "
+          "(saw: \"{}\").".format(type(self.kernel_posterior).__name__))
+    self.kernel_posterior_affine = normal_lib.Normal(
+        loc=self._matmul(inputs, self.kernel_posterior.distribution.loc),
         scale=standard_ops.sqrt(self._matmul(
             standard_ops.square(inputs),
-            standard_ops.square(self.kernel.posterior.scale))))
-    self.kernel.posterior_affine_tensor = (
-        self.kernel.posterior_tensor_fn(self.kernel.posterior_affine))
-    self.kernel.posterior_tensor = None
-    return self.kernel.posterior_affine_tensor
+            standard_ops.square(self.kernel_posterior.distribution.scale))))
+    self.kernel_posterior_affine_tensor = (
+        self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
+    self.kernel_posterior_tensor = None
+    return self.kernel_posterior_affine_tensor

   def _apply_variational_bias(self, inputs):
-    if self.bias.posterior is None:
-      self.bias.posterior_tensor = None
+    if self.bias_posterior is None:
+      self.bias_posterior_tensor = None
       return inputs
-    self.bias.posterior_tensor = self.bias.posterior_tensor_fn(
-        self.bias.posterior)
-    return nn.bias_add(inputs, self.bias.posterior_tensor)
-
-  def _apply_divergence(self, param, name):
-    if (param.divergence_fn is None or
-        param.posterior is None or
-        param.prior is None):
-      param.divergence = None
+    self.bias_posterior_tensor = self.bias_posterior_tensor_fn(
+        self.bias_posterior)
+    return nn.bias_add(inputs, self.bias_posterior_tensor)
+
+  def _apply_divergence(self, divergence_fn, posterior, prior,
+                        posterior_tensor, name):
+    if (divergence_fn is None or
+        posterior is None or
+        prior is None):
+      divergence = None
       return
-    param.divergence = standard_ops.identity(
-        param.divergence_fn(
-            param.posterior, param.prior, param.posterior_tensor),
+    divergence = standard_ops.identity(
+        divergence_fn(
+            posterior, prior, posterior_tensor),
         name=name)
-    self.add_loss(param.divergence)
+    self.add_loss(divergence)

   def _matmul(self, inputs, kernel):
     if inputs.shape.ndims <= 2:
@@ -489,37 +499,19 @@ def dense_variational(
     reuse=None):
   """Densely-connected variational layer.

-  This layer implements the Bayesian variational inference analogue to:
-  `outputs = activation(matmul(inputs, kernel) + bias)`
-  by assuming the `kernel` and/or the `bias` are random variables.
-
-  The layer implements a stochastic dense calculation by making a Monte Carlo
-  approximation of a [variational Bayesian method based on KL divergence](
-  https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e.,
+  This layer implements the Bayesian variational inference analogue to
+  a dense layer by assuming the `kernel` and/or the `bias` are drawn
+  from distributions. By default, the layer implements a stochastic
+  forward pass via sampling from the kernel and bias posteriors,

   ```none
-  -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw
-              = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw
-             <= E_q(W|x)[-log p(y,W|x) + log q(W|x)]        # Jensen's
-              = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)]
-             ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m }
-                 + KL[q(W|x), p(W)]
+  kernel, bias ~ posterior
+  outputs = activation(matmul(inputs, kernel) + bias)
   ```

-  where `W` denotes the (independent) `kernel` and `bias` random variables, `w`
-  is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`,
-  and `~=` denotes an approximation which becomes exact as `m->inf`. The above
-  bound is sometimes referred to as the negative Evidence Lower BOund or
-  negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this
-  layer is appropriate to use when the final loss is a negative log-likelihood.
-
-  The Monte-Carlo sum portion is used for the feed-forward calculation of the
-  DNN. The KL divergence portion can be added to the final loss via:
-  `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`.
-
   The arguments permit separate specification of the surrogate posterior
   (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
-  random variables (which together comprise `W`).
+  distributions.

   Args:
     inputs: Tensor input.
@@ -574,6 +566,31 @@ def dense_variational(
   Returns:
     output: `Tensor` representing a the affine transformed input under a
       random draw from the surrogate posterior distribution.
+
+  #### Examples
+
+  We illustrate a Bayesian neural network with [variational inference](
+  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+  assuming a dataset of `features` and `labels`.
+
+  ```python
+  tfp = tf.contrib.bayesflow
+
+  net = tfp.layers.dense_variational(features, 512, activation=tf.nn.relu)
+  logits = tfp.layers.dense_variational(net, 10)
+  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+      labels=labels, logits=logits)
+  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+  loss = neg_log_likelihood + kl
+  train_op = tf.train.AdamOptimizer().minimize(loss)
+  ```
+
+  It uses reparameterization gradients to minimize the
+  Kullback-Leibler divergence up to a constant, also known as the
+  negative Evidence Lower Bound. It consists of the sum of two terms:
+  the expected negative log-likelihood, which we approximate via
+  Monte Carlo; and the KL divergence, which is added via regularizer
+  terms which are arguments to the layer.
   """
   layer = DenseVariational(
       units,
@@ -595,203 +612,3 @@ def dense_variational(
       _scope=name,
       _reuse=reuse)
   return layer.apply(inputs)
-
-
-class NotSet(object):
-  """Helper to track whether a `VariationalParameter` value has been set."""
-  pass
-
-
-class VariationalParameter(object):
-  """Struct-like container of variational parameter properties.
-
-  A `VariationalParameter` is intitialized with Python `callable`s which set the
-  value of correspondingly named members. Corresponding values have "set once"
-  semantics, i.e., once set to any value they are immutable.
-  """
-
-  def __init__(
-      self,
-      posterior_fn,
-      posterior_tensor_fn,
-      prior_fn,
-      divergence_fn):
-    """Creates the `VariationalParameter` struct-like object.
-
-    Args:
-      posterior_fn: Python `callable` which creates a
-        `tf.distribution.Distribution` like object representing the posterior
-        distribution. See `VariationalParameter.posterior_fn` for `callable`'s
-        required parameters.
-      posterior_tensor_fn: Python `callable` which computes a `Tensor`
-        which represents the `posterior`.
-      prior_fn: Python `callable` which creates a
-        `tf.distribution.Distribution` like object representing the prior
-        distribution. See `VariationalParameter.prior_fn` for `callable`'s
-        required parameters.
-      divergence_fn: Python `callable` which computes the KL divergence from
-        `posterior` to `prior`. See `VariationalParameter.divergence_fn` for
-        required `callable`'s parameters.
- """ - self._posterior_fn = posterior_fn - self._posterior = NotSet() - self._posterior_tensor_fn = posterior_tensor_fn - self._posterior_tensor = NotSet() - self._prior_fn = prior_fn - self._prior = NotSet() - self._divergence_fn = divergence_fn - self._divergence = NotSet() - self._init_helper() - - @property - def posterior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like posterior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - posterior_fn: The Python `callable` specified in `__init__`. - """ - return self._posterior_fn - - @property - def posterior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._posterior - - @posterior.setter - def posterior(self, value): - """One-time setter of the `posterior` distribution.""" - if not isinstance(self._posterior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior = value - - @property - def posterior_tensor_fn(self): - """Creates `Tensor` representing the `posterior` distribution. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - - Returns: - posterior_tensor_fn: The Python `callable` specified in - `__init__`. - """ - return self._posterior_tensor_fn - - @property - def posterior_tensor(self): - """`Tensor` representing the `posterior` distribution.""" - return self._posterior_tensor - - @posterior_tensor.setter - def posterior_tensor(self, value): - """One-time setter of the `posterior_tensor`.""" - if not isinstance(self._posterior_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_tensor = value - - @property - def prior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like prior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - prior_fn: The Python `callable` specified in `__init__`. - """ - return self._prior_fn - - @property - def prior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._prior - - @prior.setter - def prior(self, value): - """One-time setter of the `prior` distribution.""" - if not isinstance(self._prior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._prior = value - - @property - def divergence_fn(self): - """`callable` which computes KL-divergence `Tensor` from posterior to prior. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - prior: `tf.distributions.Distribution`-like instance. 
-      posterior_tensor: `Tensor` representing value of posterior.
-
-    Returns:
-      divergence_fn: The Python `callable` specified in `__init__`.
-    """
-    return self._divergence_fn
-
-  @property
-  def divergence(self):
-    """`Tensor` representing KL-divergence from posterior to prior."""
-    return self._divergence
-
-  @divergence.setter
-  def divergence(self, value):
-    """One-time setter of the `divergence`."""
-    if not isinstance(self._divergence, NotSet):
-      raise ValueError("Cannot override already set attribute.")
-    self._divergence = value
-
-  def _init_helper(self):
-    pass
-
-
-class VariationalKernelParameter(VariationalParameter):
-  """Struct-like container of variational kernel properties.
-
-  A `VariationalKernelParameter` is intitialized with Python `callable`s which
-  set the value of correspondingly named members. Corresponding values have "set
-  once" semantics, i.e., once set to any value they are immutable.
-  """
-
-  @property
-  def posterior_affine(self):
-    """`tf.distributions.Distribution` affine transformed posterior."""
-    return self._posterior_affine
-
-  @posterior_affine.setter
-  def posterior_affine(self, value):
-    """One-time setter of `posterior_affine`."""
-    if not isinstance(self._posterior_affine, NotSet):
-      raise ValueError("Cannot override already set attribute.")
-    self._posterior_affine = value
-
-  @property
-  def posterior_affine_tensor(self):
-    """`Tensor` representing the `posterior_affine` distribution."""
-    return self._posterior_affine_tensor
-
-  @posterior_affine_tensor.setter
-  def posterior_affine_tensor(self, value):
-    """One-time setter of the `posterior_affine_tensor`."""
-    if not isinstance(self._posterior_affine_tensor, NotSet):
-      raise ValueError("Cannot override already set attribute.")
-    self._posterior_affine_tensor = value
-
-  def _init_helper(self):
-    self._posterior_affine = NotSet()
-    self._posterior_affine_tensor = NotSet()
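A note on the `kernel_use_local_reparameterization` path rewritten above: rather than sampling a kernel from the posterior and multiplying, the layer samples the matmul output directly from `Normal(matmul(x, loc), sqrt(matmul(x**2, scale**2)))`, which is why the posterior must now be an `Independent(Normal)` whose `.distribution` exposes per-weight `loc` and `scale`. The NumPy sketch below (illustrative only, not code from this change) checks that the two sampling schemes agree in per-element mean and variance; the local form also decorrelates the noise across the batch, which is what lowers gradient variance.

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.rand(2, 3)      # batch of inputs
loc = rng.randn(3, 4)   # posterior means of a [3, 4] kernel
scale = rng.rand(3, 4)  # posterior stddevs of the same kernel

n = 100000

# Weight-space sampling: draw a fresh kernel per Monte Carlo replicate.
w = loc + scale * rng.randn(n, 3, 4)
y_weight = np.einsum("bi,nio->nbo", x, w)

# Local reparameterization: sample the pre-activations directly.
out_loc = np.matmul(x, loc)
out_scale = np.sqrt(np.matmul(x ** 2, scale ** 2))
y_local = out_loc + out_scale * rng.randn(n, 2, 4)

# Per-element means and variances of the two schemes match.
print(np.allclose(y_weight.mean(0), y_local.mean(0), atol=0.05))  # True
print(np.allclose(y_weight.var(0), y_local.var(0), atol=0.05))    # True
```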