aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-24 12:12:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 12:17:12 -0700
commit7919d64414ed47d217b8fc508d1be56b2a531a3c (patch)
tree9d6fd19ca3932d89743ab4643a56529afc974303 /tensorflow/contrib/eager
parentf361fb8e4b4a9838e60a11ab45391c308bcb90da (diff)
Wrap forward and backward pass in a defun for L2HMC.
Also a small bugfix to handle unknown shapes in backprop._num_elements. Before: entry { name: "L2hmcBenchmark.eager_train_cpu_defun" iters: 10 wall_time: 0.594115018845 extras { key: "examples_per_sec" value { double_value: 336.635152548 } } } After: entry { name: "L2hmcBenchmark.eager_train_cpu_defun" iters: 10 wall_time: 0.322251081467 extras { key: "examples_per_sec" value { double_value: 620.634069216 } } } PiperOrigin-RevId: 214308142
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 9557479885..c38a1597b8 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -37,26 +37,32 @@ def get_default_hparams():
n_warmup_iters=3)
+def step(dynamics, optimizer, samples):
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=l2hmc.compute_loss)
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+ return loss, samples
+
+
def warmup(dynamics,
optimizer,
n_iters=1,
n_samples=200,
- loss_fn=l2hmc.compute_loss):
+ step_fn=step):
"""Warmup optimization to reduce overhead."""
samples = tf.random_normal(
shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
for _ in range(n_iters):
- _, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ _, samples = step_fn(dynamics, optimizer, samples)
def fit(dynamics,
samples,
optimizer,
- loss_fn=l2hmc.compute_loss,
+ step_fn=step,
n_iters=5000,
verbose=True,
logdir=None):
@@ -66,9 +72,7 @@ def fit(dynamics,
summary_writer = tf.contrib.summary.create_file_writer(logdir)
for i in range(n_iters):
- loss, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ loss, samples = step_fn(dynamics, optimizer, samples)
if verbose:
print("Iteration %d: loss %.4f" % (i, loss))
@@ -193,16 +197,16 @@ class L2hmcBenchmark(tf.test.Benchmark):
n_steps=hparams.n_steps,
eps=hparams.eps)
optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
- loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss
+ step_fn = tfe.defun(step) if defun else step
# Warmup to reduce initialization effect when timing
- warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
+ warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, step_fn=step_fn)
# Training
samples = tf.random_normal(
shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
start_time = time.time()
- fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters)
+ fit(dynamics, samples, optimizer, step_fn=step_fn, n_iters=hparams.n_iters)
wall_time = time.time() - start_time
examples_per_sec = hparams.n_samples / wall_time