Diffstat (limited to 'tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py')
-rw-r--r--  tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py  145
1 file changed, 85 insertions(+), 60 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 729d8525fa..14b8324e48 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -32,20 +32,28 @@ from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets
class Dynamics(tf.keras.Model):
- """Dynamics engine of naive L2HMC sampler.
-
- Args:
- x_dim: dimensionality of observed data
- loglikelihood_fn: log-likelihood function of conditional probability
- n_steps: number of leapfrog steps within each transition
- eps: initial value learnable scale of step size
- """
-
- def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1):
+ """Dynamics engine of naive L2HMC sampler."""
+
+ def __init__(self,
+ x_dim,
+ minus_loglikelihood_fn,
+ n_steps=25,
+ eps=.1,
+ np_seed=1):
+ """Initialization.
+
+ Args:
+ x_dim: dimensionality of observed data
+ minus_loglikelihood_fn: negative log-likelihood function of the conditional probability
+ n_steps: number of leapfrog steps within each transition
+ eps: initial value of the learnable step-size scale
+ np_seed: Random seed for numpy; used to control sampled masks.
+ """
super(Dynamics, self).__init__()
+ npr.seed(np_seed)
self.x_dim = x_dim
- self.potential = loglikelihood_fn
+ self.potential = minus_loglikelihood_fn
self.n_steps = n_steps
self._construct_time()
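A minimal usage sketch of the new constructor signature (assumes the TF 1.x eager API this example targets; the quadratic potential is illustrative, not part of the commit):

import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc

tf.enable_eager_execution()

def minus_loglikelihood_fn(x):
  # Potential of a unit Gaussian, up to an additive constant.
  return .5 * tf.reduce_sum(x**2., axis=1)

dynamics = l2hmc.Dynamics(
    x_dim=2,
    minus_loglikelihood_fn=minus_loglikelihood_fn,
    n_steps=25,
    eps=.1,
    np_seed=1)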
@@ -54,7 +62,7 @@ class Dynamics(tf.keras.Model):
self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
- self.eps = tfe.Variable(
+ self.eps = tf.Variable(
initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
def apply_transition(self, position):
@@ -68,8 +76,8 @@ class Dynamics(tf.keras.Model):
position, forward=False)
# Decide direction uniformly
- forward_mask = tf.cast(
- tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32)
+ batch_size = tf.shape(position)[0]
+ forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32)
backward_mask = 1. - forward_mask
# Obtain proposed states
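A numpy sketch of the direction-mask trick above: each chain element independently follows either the forward or the backward proposal (the two proposals here are made-up stand-ins):

import numpy as np

batch_size, x_dim = 4, 2
forward_mask = (np.random.rand(batch_size) > .5).astype(np.float32)
backward_mask = 1. - forward_mask
position_f = np.ones((batch_size, x_dim))    # stand-in forward proposal
position_b = -np.ones((batch_size, x_dim))   # stand-in backward proposal
position_out = (forward_mask[:, None] * position_f +
                backward_mask[:, None] * position_b)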
@@ -108,7 +116,6 @@ class Dynamics(tf.keras.Model):
position_post, momentum_post, logdet = lf_fn(position_post, momentum_post,
i)
sumlogdet += logdet
-
accept_prob = self._compute_accept_prob(position, momentum, position_post,
momentum_post, sumlogdet)
@@ -125,17 +132,17 @@ class Dynamics(tf.keras.Model):
sumlogdet += logdet
position, logdet = self._update_position_forward(position, momentum, t,
- mask)
+ mask, mask_inv)
sumlogdet += logdet
position, logdet = self._update_position_forward(position, momentum, t,
- mask_inv)
+ mask_inv, mask)
sumlogdet += logdet
momentum, logdet = self._update_momentum_forward(position, momentum, t)
sumlogdet += logdet
- return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+ return position, momentum, sumlogdet
def _backward_lf(self, position, momentum, i):
"""One backward augmented leapfrog step. See Appendix A in paper."""
@@ -149,17 +156,17 @@ class Dynamics(tf.keras.Model):
sumlogdet += logdet
position, logdet = self._update_position_backward(position, momentum, t,
- mask)
+ mask_inv, mask)
sumlogdet += logdet
position, logdet = self._update_position_backward(position, momentum, t,
- mask_inv)
+ mask, mask_inv)
sumlogdet += logdet
momentum, logdet = self._update_momentum_backward(position, momentum, t)
sumlogdet += logdet
- return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+ return position, momentum, sumlogdet
def _update_momentum_forward(self, position, momentum, t):
"""Update v in the forward leapfrog step."""
@@ -172,12 +179,11 @@ class Dynamics(tf.keras.Model):
momentum * tf.exp(scale) -
.5 * self.eps * (tf.exp(transformed) * grad - translation))
- return momentum, scale
+ return momentum, tf.reduce_sum(scale, axis=1)
- def _update_position_forward(self, position, momentum, t, mask):
+ def _update_position_forward(self, position, momentum, t, mask, mask_inv):
"""Update x in the forward leapfrog step."""
- mask_inv = 1. - mask
scale, translation, transformed = self.position_fn(
[momentum, mask * position, t])
scale *= self.eps
@@ -186,8 +192,7 @@ class Dynamics(tf.keras.Model):
mask * position +
mask_inv * (position * tf.exp(scale) + self.eps *
(tf.exp(transformed) * momentum + translation)))
-
- return position, mask_inv * scale
+ return position, tf.reduce_sum(mask_inv * scale, axis=1)
def _update_momentum_backward(self, position, momentum, t):
"""Update v in the backward leapfrog step. Inverting the forward update."""
@@ -200,21 +205,20 @@ class Dynamics(tf.keras.Model):
tf.exp(scale) * (momentum + .5 * self.eps *
(tf.exp(transformed) * grad - translation)))
- return momentum, scale
+ return momentum, tf.reduce_sum(scale, axis=1)
- def _update_position_backward(self, position, momentum, t, mask):
+ def _update_position_backward(self, position, momentum, t, mask, mask_inv):
"""Update x in the backward leapfrog step. Inverting the forward update."""
- mask_inv = 1. - mask
scale, translation, transformed = self.position_fn(
- [momentum, mask_inv * position, t])
+ [momentum, mask * position, t])
scale *= -self.eps
transformed *= self.eps
position = (
- mask_inv * position + mask * tf.exp(scale) *
- (position - self.eps * tf.exp(transformed) * momentum + translation))
+ mask * position + mask_inv * tf.exp(scale) *
+ (position - self.eps * (tf.exp(transformed) * momentum + translation)))
- return position, mask * scale
+ return position, tf.reduce_sum(mask_inv * scale, axis=1)
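A numpy sketch checking that the corrected backward update exactly inverts the forward one; the scalars s, T, t below stand in for the position network's scale, transformation, and translation outputs:

import numpy as np

eps = .1
mask = np.array([1., 0.])
mask_inv = 1. - mask
x = np.array([.7, -1.2])
v = np.array([.3, .9])
s, T, t = .3, -.2, .05  # stand-in network outputs

# Forward: only the mask_inv coordinates change, so the network input
# mask * x is identical in both directions.
x_fwd = mask * x + mask_inv * (x * np.exp(eps * s) +
                               eps * (np.exp(eps * T) * v + t))
# Backward: negated scale recovers x exactly.
x_rec = mask * x_fwd + mask_inv * np.exp(-eps * s) * (
    x_fwd - eps * (np.exp(eps * T) * v + t))
assert np.allclose(x_rec, x)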
def _compute_accept_prob(self, position, momentum, position_post,
momentum_post, sumlogdet):
@@ -222,8 +226,10 @@ class Dynamics(tf.keras.Model):
old_hamil = self.hamiltonian(position, momentum)
new_hamil = self.hamiltonian(position_post, momentum_post)
+ prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
- return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
+ # Ensure numerical stability as well as correct gradients
+ return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob))
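A numpy sketch of the guarded acceptance probability: a non-finite Hamiltonian (e.g. from a diverged trajectory) now yields acceptance probability 0 instead of propagating NaN into the loss and its gradients:

import numpy as np

old_hamil = np.array([1., np.nan])
new_hamil = np.array([.5, .5])
sumlogdet = np.zeros(2)
prob = np.exp(np.minimum(old_hamil - new_hamil + sumlogdet, 0.))
prob = np.where(np.isfinite(prob), prob, np.zeros_like(prob))
# prob == [1., 0.]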
def _construct_time(self):
"""Convert leapfrog step index into sinusoidal time."""
@@ -248,6 +254,8 @@ class Dynamics(tf.keras.Model):
self.masks = []
for _ in range(self.n_steps):
+ # Need to use npr here because tf would generate different random
+ # values across different `sess.run` calls
idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
mask = np.zeros((self.x_dim,))
mask[idx] = 1.
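The mask construction is plain numpy, so seeding npr in the constructor makes the per-step masks reproducible across sessions. A standalone sketch:

import numpy as np
import numpy.random as npr

npr.seed(1)
x_dim, n_steps = 4, 2
masks = []
for _ in range(n_steps):
  # Each leapfrog step updates a fixed random half of the coordinates.
  idx = npr.permutation(np.arange(x_dim))[:x_dim // 2]
  mask = np.zeros((x_dim,))
  mask[idx] = 1.
  masks.append(mask)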
@@ -273,19 +281,15 @@ class Dynamics(tf.keras.Model):
def grad_potential(self, position, check_numerics=True):
"""Get gradient of potential function at current location."""
- if not tf.executing_eagerly():
- # TODO(lxuechen): Change this to tfe.gradients_function when it works
- grad = tf.gradients(self.potential(position), position)[0]
- else:
+ if tf.executing_eagerly():
grad = tfe.gradients_function(self.potential)(position)[0]
-
- if check_numerics:
- return tf.check_numerics(grad, message="gradient of potential")
+ else:
+ grad = tf.gradients(self.potential(position), position)[0]
return grad
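A sketch of the eager branch above (TF 1.x API; the quadratic potential is illustrative):

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def potential(x):
  return .5 * tf.reduce_sum(x**2., axis=1)

x = tf.ones([3, 2])
# In eager mode, differentiate the python callable directly.
grad = tfe.gradients_function(potential)(x)[0]  # equals x for this potential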
-# Examples of unnormalized log density/probabilities
+# Examples of unnormalized log densities
def get_scg_energy_fn():
"""Get energy function for 2d strongly correlated Gaussian."""
@@ -295,32 +299,53 @@ def get_scg_energy_fn():
sigma_inv = tf.matrix_inverse(sigma)
def energy(x):
- """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+ """Unnormalized minus log density of 2d strongly correlated Gaussian."""
xmmu = x - mu
return .5 * tf.diag_part(
tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
- return energy
+ return energy, mu, sigma
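The factory now also returns the distribution parameters alongside the energy; a usage sketch (TF 1.x API assumed):

energy_fn, mu, sigma = get_scg_energy_fn()
x = tf.random_normal([8, 2])
neg_log_density = energy_fn(x)  # shape [8], up to the log-normalizer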
-def get_multivariate_gaussian_energy_fn(x_dim=2):
- """Get energy function for 2d strongly correlated Gaussian."""
-
- mu = tf.random_normal(shape=[x_dim])
- # Lower triangularize and positive diagonal
- l = tf.sigmoid(
- tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0))
- # Exploit Cholesky decomposition
- sigma = tf.matmul(l, tf.transpose(l))
- sigma *= 100. # Small covariance causes extreme numerical instability
- sigma_inv = tf.matrix_inverse(sigma)
+def get_rw_energy_fn():
+ """Get energy function for rough well distribution."""
+ # For small eta, the density underlying the rough-well energy is very close
+ # to a unit Gaussian; however, the gradient is greatly affected by the small
+ # cosine perturbations.
+ eta = 1e-2
+ mu = tf.constant([0., 0.])
+ sigma = tf.constant([[1., 0.], [0., 1.]])
def energy(x):
- """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+ ip = tf.reduce_sum(x**2., axis=1)
+ return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1)
- xmmu = x - mu
- return .5 * tf.diag_part(
- tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
+ return energy, mu, sigma
+
+
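A numpy sketch of the rough-well energy above: a unit-Gaussian bowl plus small-amplitude, high-frequency cosine ripples of size eta:

import numpy as np

eta = 1e-2
x = np.random.randn(4, 2)
energy = .5 * np.sum(x**2., axis=1) + eta * np.sum(np.cos(x / eta), axis=1)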
+# Loss function
+def compute_loss(dynamics, x, scale=.1, eps=1e-4):
+ """Compute loss defined in equation (8)."""
+
+ z = tf.random_normal(tf.shape(x)) # Auxiliary variable
+ x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+ z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+ # Add eps for numerical stability; follows the released implementation
+ x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+ z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+ loss = tf.reduce_mean(
+ (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+ return loss, x_out, x_accept_prob
+
+
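A hand-computed numpy check of one term of the loss: per-sample squared jump distance weighted by acceptance, combined with its reciprocal as in equation (8) (the distances and acceptance probabilities are made up):

import numpy as np

scale, eps = .1, 1e-4
sq_jump = np.array([.5, 2.])   # ||x - x_||^2 per sample
accept = np.array([.9, .3])    # acceptance probabilities
x_loss = sq_jump * accept + eps
loss = np.mean(scale / x_loss - x_loss / scale)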
+def loss_and_grads(dynamics, x, loss_fn=compute_loss):
+ """Obtain loss value and gradients."""
+ with tf.GradientTape() as tape:
+ loss_val, out, accept_prob = loss_fn(dynamics, x)
+ grads = tape.gradient(loss_val, dynamics.trainable_variables)
- return energy
+ return loss_val, grads, out, accept_prob
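A sketch of one training step built on loss_and_grads (TF 1.x eager API; the optimizer choice and batch size are illustrative, with `dynamics` as constructed earlier):

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
x = tf.random_normal([32, dynamics.x_dim])
loss, grads, x_out, accept_prob = loss_and_grads(dynamics, x)
optimizer.apply_gradients(zip(grads, dynamics.trainable_variables))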