Diffstat (limited to 'tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py')
-rw-r--r--  tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py  |  97
1 file changed, 27 insertions(+), 70 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index e33b4cae4c..9557479885 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -37,63 +37,37 @@ def get_default_hparams():
n_warmup_iters=3)
-# Relevant functions for benchmarking
-def compute_loss(dynamics, x, scale=.1, eps=1e-4):
- """Compute loss defined in equation (8)."""
-
- z = tf.random_normal(tf.shape(x))
- x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
- z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
-
- # Add eps for numerical stability; following released impl
- x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
- z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
-
- loss = tf.reduce_mean(
- (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
-
- return loss, x_out
-
-
-def loss_and_grads(dynamics, x, loss_fn=compute_loss):
- """Obtain loss value and gradients."""
-
- with tf.GradientTape() as tape:
- loss_val, x_out = loss_fn(dynamics, x)
- grads = tape.gradient(loss_val, dynamics.variables)
-
- return loss_val, grads, x_out
-
-
-def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss):
+def warmup(dynamics,
+ optimizer,
+ n_iters=1,
+ n_samples=200,
+ loss_fn=l2hmc.compute_loss):
"""Warmup optimization to reduce overhead."""
samples = tf.random_normal(
shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
for _ in range(n_iters):
- _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
+ _, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=loss_fn)
optimizer.apply_gradients(zip(grads, dynamics.variables))
def fit(dynamics,
samples,
optimizer,
- loss_fn=compute_loss,
+ loss_fn=l2hmc.compute_loss,
n_iters=5000,
verbose=True,
- logdir=None,
- decay_lr=True):
+ logdir=None):
"""Fit L2HMC sampler with given log-likelihood function."""
if logdir:
summary_writer = tf.contrib.summary.create_file_writer(logdir)
for i in range(n_iters):
- loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
- # TODO(lxuechen): Proper learning rate decay
- if decay_lr:
- grads = [grad * .96**(i // 1000) for grad in grads]
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=loss_fn)
optimizer.apply_gradients(zip(grads, dynamics.variables))
if verbose:
print("Iteration %d: loss %.4f" % (i, loss))
@@ -112,9 +86,10 @@ class L2hmcTest(tf.test.TestCase):
# Eager mode testing
hparams = get_default_hparams()
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ minus_loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim])
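The keyword rename from loglikelihood_fn to minus_loglikelihood_fn makes the convention explicit: Dynamics consumes an energy function, i.e. the negative log-density E(x) = -log p(x), so exp(-E(x)) is proportional to the target density. An illustrative energy in that convention (a standard Gaussian, not the SCG energy the test actually uses):

    # Illustrative only: energy of a standard Gaussian, E(x) = 0.5 * ||x||^2.
    energy_fn = lambda x: 0.5 * tf.reduce_sum(x ** 2, axis=-1)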
@@ -127,9 +102,10 @@ class L2hmcTest(tf.test.TestCase):
# Graph mode testing
with tf.Graph().as_default():
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ minus_loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
@@ -150,32 +126,20 @@ class L2hmcTest(tf.test.TestCase):
class L2hmcBenchmark(tf.test.Benchmark):
"""Eager and graph benchmarks for l2hmc."""
- def _get_energy_fn(self):
- """Get specific energy function according to FLAGS."""
-
- if FLAGS.energy_fn == "scg":
- energy_fn = l2hmc.get_scg_energy_fn()
- elif FLAGS.energy_fn == "multivariate_gaussian":
- energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim)
- else:
- raise ValueError("No such energy function %s" % FLAGS.energy_fn)
-
- return energy_fn
-
def benchmark_graph(self):
"""Benchmark Graph performance."""
hparams = get_default_hparams()
tf.reset_default_graph()
with tf.Graph().as_default():
- energy_fn = self._get_energy_fn()
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=energy_fn,
+ minus_loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out = compute_loss(dynamics, x)
+ loss, x_out, _ = l2hmc.compute_loss(dynamics, x)
global_step = tf.Variable(0., name="global_step", trainable=False)
learning_rate = tf.train.exponential_decay(
@@ -183,7 +147,11 @@ class L2hmcBenchmark(tf.test.Benchmark):
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)
- with tf.Session() as sess:
+ # Single thread; fairer comparison against eager
+ session_conf = tf.ConfigProto(
+ intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
+
+ with tf.Session(config=session_conf) as sess:
sess.run(tf.global_variables_initializer())
# Warmup to reduce initialization effect when timing
@@ -218,14 +186,14 @@ class L2hmcBenchmark(tf.test.Benchmark):
"""Benchmark Eager performance."""
hparams = get_default_hparams()
- energy_fn = self._get_energy_fn()
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=energy_fn,
+ minus_loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
- loss_fn = tfe.defun(compute_loss) if defun else compute_loss
+ loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss
# Warmup to reduce initialization effect when timing
warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
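tfe.defun traces a Python TF function into a graph function on its first call, so the defun=True variant of this benchmark measures graph-compiled step time while keeping eager-style code. A toy sketch of the mechanism (the function and names are illustrative):

    # Illustrative: defun compiles a Python TF function into a callable graph.
    import tensorflow.contrib.eager as tfe

    def square_sum(x):
      return tf.reduce_sum(x * x)

    fast_square_sum = tfe.defun(square_sum)  # traced on first call
    y = fast_square_sum(tf.ones([2, 3]))     # later calls reuse the graph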
@@ -234,12 +202,7 @@ class L2hmcBenchmark(tf.test.Benchmark):
samples = tf.random_normal(
shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
start_time = time.time()
- fit(dynamics,
- samples,
- optimizer,
- loss_fn=loss_fn,
- n_iters=hparams.n_iters,
- decay_lr=True)
+ fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters)
wall_time = time.time() - start_time
examples_per_sec = hparams.n_samples / wall_time
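With the decay_lr flag gone, the eager path no longer emulates decay by scaling gradients (.96 ** (i // 1000)); the graph benchmark keeps its schedule via tf.train.exponential_decay instead. A hedged sketch of that staircase pattern, with an illustrative base rate, approximating the schedule the removed flag implemented:

    # Staircase decay: multiply the rate by 0.96 every 1000 steps.
    global_step = tf.Variable(0., name="global_step", trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate=1e-3, global_step=global_step,
        decay_steps=1000, decay_rate=0.96, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)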
@@ -251,14 +214,8 @@ class L2hmcBenchmark(tf.test.Benchmark):
wall_time=wall_time)
del dynamics
- del loss_fn
if __name__ == "__main__":
- tf.flags.DEFINE_string("energy_fn", "scg",
- ("The energy function/unnormalized log-probability. "
- "Either be `scg` or `multivariate_gaussian`"))
- tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.")
- FLAGS = tf.flags.FLAGS
tf.enable_eager_execution()
tf.test.main()