| author | Akshay Modi <nareshmodi@google.com> | 2018-09-24 12:12:30 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 12:17:12 -0700 |
| commit | 7919d64414ed47d217b8fc508d1be56b2a531a3c (patch) |
| tree | 9d6fd19ca3932d89743ab4643a56529afc974303 /tensorflow/contrib/eager |
| parent | f361fb8e4b4a9838e60a11ab45391c308bcb90da (diff) |
Wrap forward and backward pass in a defun for L2HMC.
Also a small bugfix to handle unknown shapes in backprop._num_elements.
Before:
entry {
  name: "L2hmcBenchmark.eager_train_cpu_defun"
  iters: 10
  wall_time: 0.594115018845
  extras {
    key: "examples_per_sec"
    value {
      double_value: 336.635152548
    }
  }
}
After:
entry {
  name: "L2hmcBenchmark.eager_train_cpu_defun"
  iters: 10
  wall_time: 0.322251081467
  extras {
    key: "examples_per_sec"
    value {
      double_value: 620.634069216
    }
  }
}
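(Net effect: wall time drops from ~0.594 s to ~0.322 s, i.e. throughput improves by roughly 1.84x, from ~336.6 to ~620.6 examples/sec.)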
PiperOrigin-RevId: 214308142
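For context, here is a minimal, self-contained sketch of the pattern this change applies: wrapping the entire training step (forward pass, gradient computation, and optimizer update) in tfe.defun so it runs as a single traced graph function, rather than defun-ing only the loss. The one-layer model, squared loss, and shapes below are hypothetical stand-ins for the L2HMC dynamics, and TF 1.x-era eager APIs are assumed.

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def train_step(model, optimizer, x):
  # Forward pass, gradients, and the variable update are all traced into
  # a single graph function, instead of crossing the eager/graph boundary
  # separately for the loss and for apply_gradients.
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x)))
  grads = tape.gradient(loss, model.variables)
  optimizer.apply_gradients(zip(grads, model.variables))
  return loss

model = tf.keras.layers.Dense(1)
model.build((None, 10))  # create the layer's variables eagerly, before tracing
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
x = tf.random_normal([200, 10])

step_fn = tfe.defun(train_step)  # traced on the first call, reused afterwards
for _ in range(3):
  loss = step_fn(model, optimizer, x)

The design choice mirrors the diff below: moving the defun boundary from the loss function out to the whole step means the gradient computation and apply_gradients are compiled as well, which is presumably where the speedup in the benchmark above comes from.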
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py | 26 |
1 file changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 9557479885..c38a1597b8 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -37,26 +37,32 @@ def get_default_hparams():
       n_warmup_iters=3)
 
 
+def step(dynamics, optimizer, samples):
+  loss, grads, samples, _ = l2hmc.loss_and_grads(
+      dynamics, samples, loss_fn=l2hmc.compute_loss)
+  optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+  return loss, samples
+
+
 def warmup(dynamics,
            optimizer,
            n_iters=1,
            n_samples=200,
-           loss_fn=l2hmc.compute_loss):
+           step_fn=step):
   """Warmup optimization to reduce overhead."""
   samples = tf.random_normal(
       shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
 
   for _ in range(n_iters):
-    _, grads, samples, _ = l2hmc.loss_and_grads(
-        dynamics, samples, loss_fn=loss_fn)
-    optimizer.apply_gradients(zip(grads, dynamics.variables))
+    _, samples = step_fn(dynamics, optimizer, samples)
 
 
 def fit(dynamics,
         samples,
         optimizer,
-        loss_fn=l2hmc.compute_loss,
+        step_fn=step,
         n_iters=5000,
         verbose=True,
         logdir=None):
@@ -66,9 +72,7 @@ def fit(dynamics,
     summary_writer = tf.contrib.summary.create_file_writer(logdir)
 
   for i in range(n_iters):
-    loss, grads, samples, _ = l2hmc.loss_and_grads(
-        dynamics, samples, loss_fn=loss_fn)
-    optimizer.apply_gradients(zip(grads, dynamics.variables))
+    loss, samples = step_fn(dynamics, optimizer, samples)
     if verbose:
       print("Iteration %d: loss %.4f" % (i, loss))
 
@@ -193,16 +197,16 @@ class L2hmcBenchmark(tf.test.Benchmark):
         n_steps=hparams.n_steps,
         eps=hparams.eps)
     optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
-    loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss
+    step_fn = tfe.defun(step) if defun else step
 
     # Warmup to reduce initialization effect when timing
-    warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
+    warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, step_fn=step_fn)
 
     # Training
     samples = tf.random_normal(
         shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
     start_time = time.time()
-    fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters)
+    fit(dynamics, samples, optimizer, step_fn=step_fn, n_iters=hparams.n_iters)
     wall_time = time.time() - start_time
     examples_per_sec = hparams.n_samples / wall_time