Diffstat (limited to 'tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py | 145
1 file changed, 85 insertions, 60 deletions
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 729d8525fa..14b8324e48 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -32,20 +32,28 @@ from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets
 
 
 class Dynamics(tf.keras.Model):
-  """Dynamics engine of naive L2HMC sampler.
-
-  Args:
-    x_dim: dimensionality of observed data
-    loglikelihood_fn: log-likelihood function of conditional probability
-    n_steps: number of leapfrog steps within each transition
-    eps: initial value learnable scale of step size
-  """
-
-  def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1):
+  """Dynamics engine of naive L2HMC sampler."""
+
+  def __init__(self,
+               x_dim,
+               minus_loglikelihood_fn,
+               n_steps=25,
+               eps=.1,
+               np_seed=1):
+    """Initialization.
+
+    Args:
+      x_dim: dimensionality of observed data
+      minus_loglikelihood_fn: log-likelihood function of conditional probability
+      n_steps: number of leapfrog steps within each transition
+      eps: initial value learnable scale of step size
+      np_seed: Random seed for numpy; used to control sampled masks.
+    """
     super(Dynamics, self).__init__()
 
+    npr.seed(np_seed)
     self.x_dim = x_dim
-    self.potential = loglikelihood_fn
+    self.potential = minus_loglikelihood_fn
     self.n_steps = n_steps
 
     self._construct_time()
@@ -54,7 +62,7 @@ class Dynamics(tf.keras.Model):
     self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
     self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
 
-    self.eps = tfe.Variable(
+    self.eps = tf.Variable(
         initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
 
   def apply_transition(self, position):
@@ -68,8 +76,8 @@ class Dynamics(tf.keras.Model):
         position, forward=False)
 
     # Decide direction uniformly
-    forward_mask = tf.cast(
-        tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32)
+    batch_size = tf.shape(position)[0]
+    forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32)
     backward_mask = 1. - forward_mask
 
     # Obtain proposed states
@@ -108,7 +116,6 @@ class Dynamics(tf.keras.Model):
       position_post, momentum_post, logdet = lf_fn(position_post, momentum_post,
                                                    i)
      sumlogdet += logdet
-
     accept_prob = self._compute_accept_prob(position, momentum, position_post,
                                             momentum_post, sumlogdet)
 
@@ -125,17 +132,17 @@ class Dynamics(tf.keras.Model):
     sumlogdet += logdet
 
     position, logdet = self._update_position_forward(position, momentum, t,
-                                                     mask)
+                                                     mask, mask_inv)
     sumlogdet += logdet
 
     position, logdet = self._update_position_forward(position, momentum, t,
-                                                     mask_inv)
+                                                     mask_inv, mask)
     sumlogdet += logdet
 
     momentum, logdet = self._update_momentum_forward(position, momentum, t)
     sumlogdet += logdet
 
-    return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+    return position, momentum, sumlogdet
 
   def _backward_lf(self, position, momentum, i):
     """One backward augmented leapfrog step. See Appendix A in paper."""
@@ -149,17 +156,17 @@ class Dynamics(tf.keras.Model):
     sumlogdet += logdet
 
     position, logdet = self._update_position_backward(position, momentum, t,
-                                                      mask)
+                                                      mask_inv, mask)
     sumlogdet += logdet
 
     position, logdet = self._update_position_backward(position, momentum, t,
-                                                      mask_inv)
+                                                      mask, mask_inv)
     sumlogdet += logdet
 
     momentum, logdet = self._update_momentum_backward(position, momentum, t)
     sumlogdet += logdet
 
-    return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+    return position, momentum, sumlogdet
 
   def _update_momentum_forward(self, position, momentum, t):
     """Update v in the forward leapfrog step."""
@@ -172,12 +179,11 @@ class Dynamics(tf.keras.Model):
         momentum * tf.exp(scale) - .5 * self.eps *
         (tf.exp(transformed) * grad - translation))
 
-    return momentum, scale
+    return momentum, tf.reduce_sum(scale, axis=1)
 
-  def _update_position_forward(self, position, momentum, t, mask):
+  def _update_position_forward(self, position, momentum, t, mask, mask_inv):
     """Update x in the forward leapfrog step."""
 
-    mask_inv = 1. - mask
     scale, translation, transformed = self.position_fn(
         [momentum, mask * position, t])
     scale *= self.eps
@@ -186,8 +192,7 @@ class Dynamics(tf.keras.Model):
         mask * position + mask_inv *
         (position * tf.exp(scale) + self.eps *
          (tf.exp(transformed) * momentum + translation)))
-
-    return position, mask_inv * scale
+    return position, tf.reduce_sum(mask_inv * scale, axis=1)
 
   def _update_momentum_backward(self, position, momentum, t):
     """Update v in the backward leapfrog step. Inverting the forward update."""
@@ -200,21 +205,20 @@ class Dynamics(tf.keras.Model):
         tf.exp(scale) * (momentum + .5 * self.eps *
                          (tf.exp(transformed) * grad - translation))
 
-    return momentum, scale
+    return momentum, tf.reduce_sum(scale, axis=1)
 
-  def _update_position_backward(self, position, momentum, t, mask):
+  def _update_position_backward(self, position, momentum, t, mask, mask_inv):
     """Update x in the backward leapfrog step. Inverting the forward update."""
 
-    mask_inv = 1. - mask
     scale, translation, transformed = self.position_fn(
-        [momentum, mask_inv * position, t])
+        [momentum, mask * position, t])
     scale *= -self.eps
     transformed *= self.eps
     position = (
-        mask_inv * position + mask * tf.exp(scale) *
-        (position - self.eps * tf.exp(transformed) * momentum + translation))
+        mask * position + mask_inv * tf.exp(scale) *
+        (position - self.eps * (tf.exp(transformed) * momentum + translation)))
 
-    return position, mask * scale
+    return position, tf.reduce_sum(mask_inv * scale, axis=1)
 
   def _compute_accept_prob(self, position, momentum, position_post,
                            momentum_post, sumlogdet):
@@ -222,8 +226,10 @@ class Dynamics(tf.keras.Model):
 
     old_hamil = self.hamiltonian(position, momentum)
     new_hamil = self.hamiltonian(position_post, momentum_post)
+    prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
 
-    return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
+    # Ensure numerical stability as well as correct gradients
+    return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob))
 
   def _construct_time(self):
     """Convert leapfrog step index into sinusoidal time."""
@@ -248,6 +254,8 @@ class Dynamics(tf.keras.Model):
 
     self.masks = []
     for _ in range(self.n_steps):
+      # Need to use npr here because tf would generated different random
+      # values across different `sess.run`
      idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
       mask = np.zeros((self.x_dim,))
       mask[idx] = 1.
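
The mask bookkeeping above is the subtle part of this change: the position updates now receive both `mask` and `mask_inv` explicitly, and `_backward_lf` applies them in the swapped order relative to `_forward_lf`, so each backward sub-step undoes the corresponding forward sub-step. Because `position_fn` only sees the coordinates that a sub-step leaves fixed (`mask * position`), it produces identical `scale`/`translation`/`transformed` outputs in both directions. The following is a minimal numpy sketch of this invertibility, not the TensorFlow implementation; `position_forward`, `position_backward`, and the frozen stand-ins `S`, `T`, `Q` (standing in for the `position_fn` network outputs) are illustrative names only:

import numpy as np

eps = 0.1

def position_forward(x, v, mask, mask_inv, S, T, Q):
  # Mirrors _update_position_forward:
  #   x' = m * x + m_inv * (x * exp(eps*S) + eps * (exp(eps*Q) * v + T))
  return (mask * x + mask_inv *
          (x * np.exp(eps * S) + eps * (np.exp(eps * Q) * v + T)))

def position_backward(x, v, mask, mask_inv, S, T, Q):
  # Mirrors _update_position_backward: exp(-eps*S) undoes the rescaling
  return (mask * x + mask_inv * np.exp(-eps * S) *
          (x - eps * (np.exp(eps * Q) * v + T)))

rng = np.random.RandomState(0)
x, v = rng.randn(4), rng.randn(4)
S, T, Q = rng.randn(3)            # frozen stand-ins for network outputs
mask = np.array([1., 0., 1., 0.])
mask_inv = 1. - mask

x_new = position_forward(x, v, mask, mask_inv, S, T, Q)
x_back = position_backward(x_new, v, mask, mask_inv, S, T, Q)
assert np.allclose(x, x_back)     # backward update inverts forward update

The sketch also shows why each update now returns `tf.reduce_sum(mask_inv * scale, axis=1)` as its per-sample log-Jacobian: only the `mask_inv` coordinates are rescaled, and the backward update negates `scale`, so the two contributions cancel over a forward-then-backward round trip.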
@@ -273,19 +281,15 @@ class Dynamics(tf.keras.Model):
   def grad_potential(self, position, check_numerics=True):
     """Get gradient of potential function at current location."""
 
-    if not tf.executing_eagerly():
-      # TODO(lxuechen): Change this to tfe.gradients_function when it works
-      grad = tf.gradients(self.potential(position), position)[0]
-    else:
+    if tf.executing_eagerly():
       grad = tfe.gradients_function(self.potential)(position)[0]
-
-    if check_numerics:
-      return tf.check_numerics(grad, message="gradient of potential")
+    else:
+      grad = tf.gradients(self.potential(position), position)[0]
 
     return grad
 
 
-# Examples of unnormalized log density/probabilities
+# Examples of unnormalized log densities
 def get_scg_energy_fn():
   """Get energy function for 2d strongly correlated Gaussian."""
 
@@ -295,32 +299,53 @@ def get_scg_energy_fn():
   sigma_inv = tf.matrix_inverse(sigma)
 
   def energy(x):
-    """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+    """Unnormalized minus log density of 2d strongly correlated Gaussian."""
 
     xmmu = x - mu
     return .5 * tf.diag_part(
         tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
 
-  return energy
+  return energy, mu, sigma
 
 
-def get_multivariate_gaussian_energy_fn(x_dim=2):
-  """Get energy function for 2d strongly correlated Gaussian."""
-
-  mu = tf.random_normal(shape=[x_dim])
-  # Lower triangularize and positive diagonal
-  l = tf.sigmoid(
-      tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0))
-  # Exploit Cholesky decomposition
-  sigma = tf.matmul(l, tf.transpose(l))
-  sigma *= 100.  # Small covariance causes extreme numerical instability
-  sigma_inv = tf.matrix_inverse(sigma)
+def get_rw_energy_fn():
+  """Get energy function for rough well distribution."""
+  # For small eta, the density underlying the rough-well energy is very close to
+  # a unit Gaussian; however, the gradient is greatly affected by the small
+  # cosine perturbations
+  eta = 1e-2
+  mu = tf.constant([0., 0.])
+  sigma = tf.constant([[1., 0.], [0., 1.]])
 
   def energy(x):
-    """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+    ip = tf.reduce_sum(x**2., axis=1)
+    return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1)
 
-    xmmu = x - mu
-    return .5 * tf.diag_part(
-        tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
+  return energy, mu, sigma
+
+
+# Loss function
+def compute_loss(dynamics, x, scale=.1, eps=1e-4):
+  """Compute loss defined in equation (8)."""
+
+  z = tf.random_normal(tf.shape(x))  # Auxiliary variable
+  x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+  z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+  # Add eps for numerical stability; following released impl
+  x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+  z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+  loss = tf.reduce_mean(
+      (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+  return loss, x_out, x_accept_prob
+
+
+def loss_and_grads(dynamics, x, loss_fn=compute_loss):
+  """Obtain loss value and gradients."""
+  with tf.GradientTape() as tape:
+    loss_val, out, accept_prob = loss_fn(dynamics, x)
+  grads = tape.gradient(loss_val, dynamics.trainable_variables)
 
-  return energy
+  return loss_val, grads, out, accept_prob
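
With `compute_loss` and `loss_and_grads` now part of l2hmc.py, a training step reduces to a `tf.GradientTape`, a gradient, and an optimizer update. The loop below is a hypothetical eager-mode sketch of how these pieces fit together; the optimizer choice, learning rate, batch size, and step count are illustrative and not part of this change:

import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc

tf.enable_eager_execution()

# get_scg_energy_fn() now returns (energy_fn, mu, sigma)
energy_fn, _, _ = l2hmc.get_scg_energy_fn()
dynamics = l2hmc.Dynamics(
    x_dim=2, minus_loglikelihood_fn=energy_fn, n_steps=10, eps=.1)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

samples = tf.random_normal(shape=[64, 2])
for _ in range(200):
  loss, grads, samples, accept_prob = l2hmc.loss_and_grads(dynamics, samples)
  # trainable_variables covers the two GenericNets and the learnable step
  # size eps, which is an ordinary tf.Variable after this change
  optimizer.apply_gradients(zip(grads, dynamics.trainable_variables))

Feeding the returned `x_out` back in as the next batch keeps the chain running while the sampler trains, and `accept_prob` can be monitored to track mixing.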