aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/examples/l2hmc/main.py
blob: 45e1f98429f48749d374c2aefd8874690c3830ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""L2HMC on simple Gaussian mixture model with TensorFlow eager."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc
try:
  import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
  HAS_MATPLOTLIB = True
except ImportError:
  HAS_MATPLOTLIB = False
tfe = tf.contrib.eager


def main(_):
  tf.enable_eager_execution()
  global_step = tf.train.get_or_create_global_step()
  global_step.assign(1)

  energy_fn, mean, covar = {
      "scg": l2hmc.get_scg_energy_fn(),
      "rw": l2hmc.get_rw_energy_fn()
  }[FLAGS.energy_fn]

  x_dim = 2
  train_iters = 5000
  eval_iters = 2000
  eps = 0.1
  n_steps = 10  # Chain length
  n_samples = 200
  record_loss_every = 100

  dynamics = l2hmc.Dynamics(
      x_dim=x_dim, minus_loglikelihood_fn=energy_fn, n_steps=n_steps, eps=eps)
  learning_rate = tf.train.exponential_decay(
      1e-3, global_step, 1000, 0.96, staircase=True)
  optimizer = tf.train.AdamOptimizer(learning_rate)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, dynamics=dynamics, global_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path:\"{}\" ".format(latest_path))
      sys.stdout.flush()

  if not FLAGS.restore:
    # Training
    if FLAGS.use_defun:
      # Use `tfe.deun` to boost performance when there are lots of small ops
      loss_fn = tfe.defun(l2hmc.compute_loss)
    else:
      loss_fn = l2hmc.compute_loss

    samples = tf.random_normal(shape=[n_samples, x_dim])
    for i in range(1, train_iters + 1):
      loss, samples, accept_prob = train_one_iter(
          dynamics,
          samples,
          optimizer,
          loss_fn=loss_fn,
          global_step=global_step)

      if i % record_loss_every == 0:
        print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format(
            i, loss.numpy(),
            accept_prob.numpy().mean()))
        if FLAGS.train_dir:
          with summary_writer.as_default():
            with tf.contrib.summary.always_record_summaries():
              tf.contrib.summary.scalar("Training loss", loss, step=global_step)
    print("Training complete.")
    sys.stdout.flush()

    if FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\" ".format(saved_path))
      sys.stdout.flush()

  # Evaluation
  if FLAGS.use_defun:
    # Use tfe.deun to boost performance when there are lots of small ops
    apply_transition = tfe.defun(dynamics.apply_transition)
  else:
    apply_transition = dynamics.apply_transition

  samples = tf.random_normal(shape=[n_samples, x_dim])
  samples_history = []
  for i in range(eval_iters):
    samples_history.append(samples.numpy())
    _, _, _, samples = apply_transition(samples)
  samples_history = np.array(samples_history)
  print("Sampling complete.")
  sys.stdout.flush()

  # Mean and covariance of target distribution
  mean = mean.numpy()
  covar = covar.numpy()
  ac_spectrum = compute_ac_spectrum(samples_history, mean, covar)
  print("First 25 entries of the auto-correlation spectrum: {}".format(
      ac_spectrum[:25]))
  ess = compute_ess(ac_spectrum)
  print("Effective sample size per Metropolis-Hastings step: {}".format(ess))
  sys.stdout.flush()

  if FLAGS.train_dir:
    # Plot autocorrelation spectrum in tensorboard
    plot_step = tfe.Variable(1, trainable=False, dtype=tf.int64)

    for ac in ac_spectrum:
      with summary_writer.as_default():
        with tf.contrib.summary.always_record_summaries():
          tf.contrib.summary.scalar("Autocorrelation", ac, step=plot_step)
      plot_step.assign(plot_step + n_steps)

    if HAS_MATPLOTLIB:
      # Choose a single chain and plot the trajectory
      single_chain = samples_history[:, 0, :]
      xs = single_chain[:100, 0]
      ys = single_chain[:100, 1]
      plt.figure()
      plt.plot(xs, ys, color="orange", marker="o", alpha=0.6)  # Trained chain
      plt.savefig(os.path.join(FLAGS.train_dir, "single_chain.png"))


def train_one_iter(dynamics,
                   x,
                   optimizer,
                   loss_fn=l2hmc.compute_loss,
                   global_step=None):
  """Train the sampler for one iteration."""
  loss, grads, out, accept_prob = l2hmc.loss_and_grads(
      dynamics, x, loss_fn=loss_fn)
  optimizer.apply_gradients(
      zip(grads, dynamics.trainable_variables), global_step=global_step)

  return loss, out, accept_prob


def compute_ac_spectrum(samples_history, target_mean, target_covar):
  """Compute autocorrelation spectrum.

  Follows equation 15 from the L2HMC paper.

  Args:
    samples_history: Numpy array of shape [T, B, D], where T is the total
        number of time steps, B is the batch size, and D is the dimensionality
        of sample space.
    target_mean: 1D Numpy array of the mean of target(true) distribution.
    target_covar: 2D Numpy array representing a symmetric matrix for variance.
  Returns:
    Autocorrelation spectrum, Numpy array of shape [T-1].
  """

  # Using numpy here since eager is a bit slow due to the loop
  time_steps = samples_history.shape[0]
  trace = np.trace(target_covar)

  rhos = []
  for t in range(time_steps - 1):
    rho_t = 0.
    for tau in range(time_steps - t):
      v_tau = samples_history[tau, :, :] - target_mean
      v_tau_plus_t = samples_history[tau + t, :, :] - target_mean
      # Take dot product over observation dims and take mean over batch dims
      rho_t += np.mean(np.sum(v_tau * v_tau_plus_t, axis=1))

    rho_t /= trace * (time_steps - t)
    rhos.append(rho_t)

  return np.array(rhos)


def compute_ess(ac_spectrum):
  """Compute the effective sample size based on autocorrelation spectrum.

  This follows equation 16 from the L2HMC paper.

  Args:
    ac_spectrum: Autocorrelation spectrum
  Returns:
    The effective sample size
  """
  # Cutoff from the first value less than 0.05
  cutoff = np.argmax(ac_spectrum[1:] < .05)
  if cutoff == 0:
    cutoff = len(ac_spectrum)
  ess = 1. / (1. + 2. * np.sum(ac_spectrum[1:cutoff]))
  return ess


if __name__ == "__main__":
  flags.DEFINE_string(
      "train_dir",
      default=None,
      help="[Optional] Directory to store the training information")
  flags.DEFINE_boolean(
      "restore",
      default=False,
      help="[Optional] Restore the latest checkpoint from `train_dir` if True")
  flags.DEFINE_boolean(
      "use_defun",
      default=False,
      help="[Optional] Use `tfe.defun` to boost performance")
  flags.DEFINE_string(
      "energy_fn",
      default="scg",
      help="[Optional] The energy function used for experimentation"
      "Other options include `rw`")
  FLAGS = flags.FLAGS
  tf.app.run(main)