author     Xuechen Li <lxuechen@google.com>                  2018-07-26 11:29:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-26 11:37:33 -0700
commit     e91c597b9aec76bb953567d84e6b92c3b2f5df8f
tree       93ac8c5443e728964c6be99f862d5dce3d81d593
parent     a8218323db98a504fe359568c97d0c7e1b978c47
Make model fully defunable.
PiperOrigin-RevId: 206192038
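
Note on the change: every method needed for a training step is now functional enough to be
staged with `tfe.defun`. Gradient computation returns a plain list ordered to match
`model.trainable_variables` instead of also returning the variables, batch-norm moving
averages are snapshotted and restored explicitly, and the optimizer update lives in a free
function. A minimal sketch of the intended wrapping (it mirrors the main.py change below;
`model` and `optimizer` are assumed to be already constructed):

    import tensorflow as tf
    tfe = tf.contrib.eager

    def apply_gradients(optimizer, grads, vars_, global_step=None):
      """Functional-style apply_gradients so the update itself can be defun-compiled."""
      optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)

    model.call = tfe.defun(model.call)
    model.compute_gradients = tfe.defun(model.compute_gradients)
    model.get_moving_stats = tfe.defun(model.get_moving_stats)
    model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
    apply_gradients = tfe.defun(apply_gradients)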
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks.py              |  24
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks_test.py         |  15
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main.py                |  20
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main_estimator.py      |  11
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py  |  21
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet.py              | 126
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet_test.py         |  26
7 files changed, 105 insertions(+), 138 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 8a530b0d71..2cb04ed258 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -91,12 +91,10 @@ class RevBlock(tf.keras.Model):
h = block(h, training=training)
return h
- def backward_grads_and_vars(self, x, y, dy, training=True):
+ def backward_grads(self, x, y, dy, training=True):
"""Apply reversible block backward to outputs."""
grads_all = []
- vars_all = []
-
for i in reversed(range(len(self.blocks))):
block = self.blocks[i]
if i == 0:
@@ -104,19 +102,15 @@ class RevBlock(tf.keras.Model):
with tf.GradientTape() as tape:
tape.watch(x)
y = block(x, training=training)
-
grads_combined = tape.gradient(
y, [x] + block.trainable_variables, output_gradients=dy)
dy = grads_combined[0]
- grads_all += grads_combined[1:]
- vars_all += block.trainable_variables
+ grads_all = grads_combined[1:] + grads_all
else:
- y, dy, grads, vars_ = block.backward_grads_and_vars(
- y, dy, training=training)
- grads_all += grads
- vars_all += vars_
+ y, dy, grads = block.backward_grads(y, dy, training=training)
+ grads_all = grads + grads_all
- return dy, grads_all, vars_all
+ return dy, grads_all
class _Residual(tf.keras.Model):
@@ -195,7 +189,7 @@ class _Residual(tf.keras.Model):
return tf.concat([y1, y2], axis=self.axis)
- def backward_grads_and_vars(self, y, dy, training=True):
+ def backward_grads(self, y, dy, training=True):
"""Manually compute backward gradients given input and output grads."""
dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
@@ -219,13 +213,11 @@ class _Residual(tf.keras.Model):
del tape
- grads = df + dg
- vars_ = self.f.trainable_variables + self.g.trainable_variables
-
x = tf.concat([x1, x2], axis=self.axis)
dx = tf.concat([dx1, dx2], axis=self.axis)
+ grads = df + dg
- return x, dx, grads, vars_
+ return x, dx, grads
# Ideally, the following should be wrapped in `tf.keras.Sequential`, however
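
With `backward_grads`, each block returns only its gradients; they are prepended
(`grads + grads_all`) while walking the blocks in reverse, so the assembled list lines up
with the order of `self.trainable_variables`. A minimal usage sketch, assuming a RevBlock
`block` that has already been called once, matching tensors `x`, `y`, `dy`, and an
optimizer built elsewhere (the optimizer is an assumption, not part of this diff):

    dx, grads = block.backward_grads(x, y, dy, training=True)
    optimizer.apply_gradients(zip(grads, block.trainable_variables))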
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index d74785c8fe..3c6ea63e48 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -179,7 +179,7 @@ class RevBlockTest(tf.test.TestCase):
degree = compute_degree(g1, g2)
self.assertLessEqual(degree, atol)
- def test_backward_grads_and_vars_channels_first(self):
+ def test_backward_grads_channels_first(self):
"""Test `backward` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
@@ -201,7 +201,8 @@ class RevBlockTest(tf.test.TestCase):
tape.watch(x)
y = block(x, training=True)
# Compute grads from reconstruction
- dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ dx, dw = block.backward_grads(x, y, dy, training=True)
+ vars_ = block.trainable_variables
# Compute true grads
grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
@@ -224,7 +225,8 @@ class RevBlockTest(tf.test.TestCase):
tape.watch(x)
y = block(x, training=True)
# Compute grads from reconstruction
- dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ dx, dw = block.backward_grads(x, y, dy, training=True)
+ vars_ = block.trainable_variables
# Compute true grads
grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
@@ -245,7 +247,7 @@ class _ResidualTest(tf.test.TestCase):
_validate_block_call_channels_first(blocks._Residual, self)
_validate_block_call_channels_last(blocks._Residual, self)
- def test_backward_grads_and_vars_channels_first(self):
+ def test_backward_grads_channels_first(self):
"""Test `backward_grads` function with `channels_first` data format."""
if not tf.test.is_gpu_available():
self.skipTest("GPU not available")
@@ -269,9 +271,8 @@ class _ResidualTest(tf.test.TestCase):
y = residual(x_true, training=True)
# Gradients computed due to reversibility
- x, dx, dw, vars_ = residual.backward_grads_and_vars(
- y, dy=dy, training=True)
-
+ x, dx, dw = residual.backward_grads(y, dy=dy, training=True)
+ vars_ = residual.trainable_variables
# True gradients computed by the tape
grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy)
dx_true, dw_true = grads[0], grads[1:]
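
Because the blocks no longer hand back their variables, the tests read
`block.trainable_variables` directly and compare the reconstructed gradients against a
`tf.GradientTape` baseline. Stripped of the test-class scaffolding, the check looks like
this sketch; `block`, `x`, and `dy` are assumed to be built as in the tests above, with
eager execution enabled:

    with tf.GradientTape() as tape:
      tape.watch(x)
      y = block(x, training=True)
    # Gradients reconstructed from reversibility
    dx, dw = block.backward_grads(x, y, dy, training=True)
    vars_ = block.trainable_variables
    # Reference gradients from the tape
    grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
    dx_true, dw_true = grads[0], grads[1:]
    # The test then asserts dx is close to dx_true and each dw entry to dw_true.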
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index dcd4e1697f..b702e91f92 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.eager.python.examples.revnet import revnet
tfe = tf.contrib.eager
+def apply_gradients(optimizer, grads, vars_, global_step=None):
+ """Functional style apply_grads for `tfe.defun`."""
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+
+
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
tf.enable_eager_execution()
@@ -48,6 +53,11 @@ def main(_):
if FLAGS.use_defun:
model.call = tfe.defun(model.call)
+ model.compute_gradients = tfe.defun(model.compute_gradients)
+ model.get_moving_stats = tfe.defun(model.get_moving_stats)
+ model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
+ global apply_gradients # pylint:disable=global-variable-undefined
+ apply_gradients = tfe.defun(apply_gradients)
if FLAGS.train_dir:
summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
@@ -197,9 +207,13 @@ def get_datasets(data_dir, config):
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ logits, saved_hiddens = model(inputs, training=True)
+ values = model.get_moving_stats()
+ grads, loss = model.compute_gradients(saved_hiddens, labels)
+ # Restore moving averages when executing eagerly to avoid updating twice
+ model.restore_moving_stats(values)
+ apply_gradients(
+ optimizer, grads, model.trainable_variables, global_step=global_step)
return logits, loss
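
The eager training step now brackets the forward pass with
`get_moving_stats`/`restore_moving_stats` so batch-norm running averages advance only once
per iteration, and applies gradients through the module-level `apply_gradients` (possibly
its `tfe.defun`-compiled replacement). A minimal driver sketch; the dataset and loop names
are assumptions, not part of this diff:

    for image, label in ds_train:
      logits, loss = train_one_iter(
          model, image, label, optimizer, global_step=global_step)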
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
index 4868f1931f..df25b5066f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
@@ -53,10 +53,10 @@ def model_fn(features, labels, mode, params):
global_step, config.lr_decay_steps, config.lr_list)
optimizer = tf.train.MomentumOptimizer(
learning_rate, momentum=config.momentum)
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
+ logits, saved_hidden = model(inputs, training=True)
+ grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
else:
@@ -130,8 +130,7 @@ def get_input_fn(config, data_dir, split):
return input_fn
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# RevNet specific configuration
@@ -197,4 +196,4 @@ if __name__ == "__main__":
help="[Optional] Architecture of network. "
"Other options include `revnet-110` and `revnet-164`")
FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
+ tf.app.run()
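
In graph mode the estimator follows the same two-call pattern: run the model once to
collect the saved hidden states, pass them to `compute_gradients`, and pair the returned
gradients with `model.trainable_variables`. A condensed sketch of the TRAIN branch only;
the learning-rate setup is a placeholder, since the real file derives it from
`config.lr_decay_steps` and `config.lr_list`:

    import tensorflow as tf
    from tensorflow.contrib.eager.python.examples.revnet import revnet

    def model_fn(features, labels, mode, params):
      config = params["config"]
      model = revnet.RevNet(config=config)
      if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        learning_rate = 0.1  # placeholder; the real model_fn uses a decay schedule
        optimizer = tf.train.MomentumOptimizer(
            learning_rate, momentum=config.momentum)
        logits, saved_hidden = model(features, training=True)
        grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
        train_op = optimizer.apply_gradients(
            zip(grads, model.trainable_variables), global_step=global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)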
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
index d809bcd287..f0aad9b110 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
@@ -47,7 +47,6 @@ def model_fn(features, labels, mode, params):
if isinstance(inputs, dict):
inputs = features["image"]
- FLAGS = params["FLAGS"] # pylint:disable=invalid-name,redefined-outer-name
config = params["config"]
model = revnet.RevNet(config=config)
@@ -61,14 +60,10 @@ def model_fn(features, labels, mode, params):
if FLAGS.use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
- # Define gradients
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
+ logits, saved_hidden = model(inputs, training=True)
+ grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
train_op = optimizer.apply_gradients(
- zip(grads, vars_), global_step=global_step)
-
- names = [v.name for v in model.variables]
- tf.logging.warn("{}".format(names))
+ zip(grads, model.trainable_variables), global_step=global_step)
return tf.contrib.tpu.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)
@@ -141,8 +136,7 @@ def get_input_fn(config, data_dir, split):
return input_fn
-def main(argv):
- FLAGS = argv[0] # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# RevNet specific configuration
@@ -177,10 +171,7 @@ def main(argv):
train_batch_size=config.tpu_batch_size,
eval_batch_size=config.tpu_eval_batch_size,
config=run_config,
- params={
- "FLAGS": FLAGS,
- "config": config,
- })
+ params={"config": config})
# Construct input functions
train_input_fn = get_input_fn(
@@ -325,4 +316,4 @@ if __name__ == "__main__":
" possible (i.e. up to --train_steps, which evaluates the model only"
" after finishing the entire training regime)."))
FLAGS = flags.FLAGS
- tf.app.run(main=main, argv=[FLAGS])
+ tf.app.run()
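
`FLAGS` no longer travels through `params` or `argv`; `main(_)` simply reads the
module-level `FLAGS` that `tf.app.run()` parses before invoking `main`. The convention,
sketched with the flag definitions elided (`use_tpu` is a flag this file already reads):

    import tensorflow as tf

    flags = tf.flags  # the file may import absl.flags instead; either works here
    FLAGS = flags.FLAGS

    def main(_):
      # By the time main runs, tf.app.run() has parsed sys.argv into FLAGS.
      tf.logging.info("use_tpu=%s", FLAGS.use_tpu)

    if __name__ == "__main__":
      tf.app.run()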
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index b1cb312b74..1f2cb14972 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,7 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -45,6 +44,7 @@ class RevNet(tf.keras.Model):
self._init_block = blocks.InitBlock(config=self.config)
self._final_block = blocks.FinalBlock(config=self.config)
self._block_list = self._construct_intermediate_blocks()
+ self._moving_average_variables = []
def _construct_intermediate_blocks(self):
# Precompute input shape after initial block
@@ -128,126 +128,90 @@ class RevNet(tf.keras.Model):
return tf.reduce_mean(cross_ent)
- def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
+ def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True):
"""Manually computes gradients.
- When eager execution is enabled, this method also SILENTLY updates the
- running averages of batch normalization when `training` is set to True.
+ This method silently updates the running averages of batch normalization.
Args:
- inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+ saved_hidden: List of hidden states Tensors
labels: One-hot labels for classification
training: Use the mini-batch stats in batch norm if set to True
l2_reg: Apply l2 regularization
Returns:
- A tuple with the first entry being a list of all gradients, the second
- entry being a list of respective variables, the third being the logits,
- and the forth being the loss
+ A tuple with the first entry being a list of all gradients and the second
+ being the loss
"""
- # Run forward pass to record hidden states
- vars_and_vals = self.get_moving_stats()
- _, saved_hidden = self(inputs, training=training) # pylint:disable=not-callable
- if tf.executing_eagerly():
- # Restore moving averages when executing eagerly to avoid updating twice
- self.restore_moving_stats(vars_and_vals)
- else:
- # Fetch batch norm updates in graph mode
- updates = self.get_updates_for(inputs)
-
- grads_all = []
- vars_all = []
+ def _defunable_pop(l):
+ """Functional style list pop that works with `tfe.defun`."""
+ t, l = l[-1], l[:-1]
+ return t, l
- # Manually backprop through last block
+ # Backprop through last block
x = saved_hidden[-1]
with tf.GradientTape() as tape:
tape.watch(x)
- # Running stats updated here
logits = self._final_block(x, training=training)
loss = self.compute_loss(logits, labels)
-
grads_combined = tape.gradient(loss,
[x] + self._final_block.trainable_variables)
- dy, grads_ = grads_combined[0], grads_combined[1:]
- grads_all += grads_
- vars_all += self._final_block.trainable_variables
+ dy, final_grads = grads_combined[0], grads_combined[1:]
- # Manually backprop through intermediate blocks
+ # Backprop through intermediate blocks
+ intermediate_grads = []
for block in reversed(self._block_list):
- y = saved_hidden.pop()
+ y, saved_hidden = _defunable_pop(saved_hidden)
x = saved_hidden[-1]
- # Running stats updated here
- dy, grads, vars_ = block.backward_grads_and_vars(
- x, y, dy, training=training)
- grads_all += grads
- vars_all += vars_
-
- # Manually backprop through first block
- saved_hidden.pop()
- x = saved_hidden.pop()
- assert not saved_hidden # Cleared after backprop
+ dy, grads = block.backward_grads(x, y, dy, training=training)
+ intermediate_grads = grads + intermediate_grads
+ # Backprop through first block
+ _, saved_hidden = _defunable_pop(saved_hidden)
+ x, saved_hidden = _defunable_pop(saved_hidden)
+ assert not saved_hidden
with tf.GradientTape() as tape:
- # Running stats updated here
y = self._init_block(x, training=training)
-
- grads_all += tape.gradient(
+ init_grads = tape.gradient(
y, self._init_block.trainable_variables, output_gradients=dy)
- vars_all += self._init_block.trainable_variables
- # Apply weight decay
+ # Ordering matches `model.trainable_variables`
+ grads_all = init_grads + final_grads + intermediate_grads
if l2_reg:
- grads_all = self._apply_weight_decay(grads_all, vars_all)
-
- if not tf.executing_eagerly():
- # Force updates to be executed before gradient computation in graph mode
- # This does nothing when the function is wrapped in defun
- with tf.control_dependencies(updates):
- grads_all[0] = tf.identity(grads_all[0])
+ grads_all = self._apply_weight_decay(grads_all)
- return grads_all, vars_all, logits, loss
+ return grads_all, loss
- def _apply_weight_decay(self, grads, vars_):
+ def _apply_weight_decay(self, grads):
"""Update gradients to reflect weight decay."""
- # Don't decay bias
return [
g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
- for g, v in zip(grads, vars_)
+ for g, v in zip(grads, self.trainable_variables)
]
def get_moving_stats(self):
- """Get moving averages of batch normalization.
-
- This is needed to avoid updating the running average twice in one iteration.
-
- Returns:
- A dictionary mapping variables for batch normalization moving averages
- to their current values.
- """
- vars_and_vals = {}
-
- def _is_moving_var(v):
- n = v.name
- return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+ """Get moving averages of batch normalization."""
+ device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+ with tf.device(device):
+ return [v.read_value() for v in self.moving_average_variables]
+ def restore_moving_stats(self, values):
+ """Restore moving averages of batch normalization."""
device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
with tf.device(device):
- for v in filter(_is_moving_var, self.variables):
- vars_and_vals[v] = v.read_value()
+ for var_, val in zip(self.moving_average_variables, values):
+ var_.assign(val)
- return vars_and_vals
+ @property
+ def moving_average_variables(self):
+ """Get all variables that are batch norm moving averages."""
- def restore_moving_stats(self, vars_and_vals):
- """Restore moving averages of batch normalization.
+ def _is_moving_avg(v):
+ n = v.name
+ return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
- This is needed to avoid updating the running average twice in one iteration.
+ if not self._moving_average_variables:
+ self._moving_average_variables = filter(_is_moving_avg, self.variables)
- Args:
- vars_and_vals: The dictionary mapping variables to their previous values.
- """
- device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
- with tf.device(device):
- for var_, val in six.iteritems(vars_and_vals):
- # `assign` causes a copy to GPU (if variable is already on GPU)
- var_.assign(val)
+ return self._moving_average_variables
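
Two pieces make `compute_gradients` defun-friendly: gradients are assembled as
`init_grads + final_grads + intermediate_grads` so their order matches
`model.trainable_variables`, and the saved hidden states are drained with a functional pop
rather than `list.pop`, which mutates its argument in place, something the commit avoids
for `tfe.defun` compatibility. The pop pattern in isolation (the names `h0`, `h1`, `h2`
are just placeholders):

    def _defunable_pop(l):
      """Functional-style list pop: return (last element, remaining list)."""
      t, l = l[-1], l[:-1]
      return t, l

    hiddens = [h0, h1, h2]
    top, hiddens = _defunable_pop(hiddens)   # top is h2; hiddens is [h0, h1]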
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 26b0847523..84b2ddf0de 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -31,9 +31,11 @@ tfe = tf.contrib.eager
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- grads, vars_, logits, loss = model.compute_gradients(
- inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ logits, saved_hidden = model(inputs)
+ grads, loss = model.compute_gradients(
+ saved_hidden=saved_hidden, labels=labels)
+ optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return logits, loss
@@ -96,9 +98,10 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients(self):
"""Test `compute_gradients` function."""
- self.model(self.x, training=False) # Initialize model
- grads, vars_, logits, loss = self.model.compute_gradients(
- inputs=self.x, labels=self.t, training=True, l2_reg=True)
+ _, saved_hidden = self.model(self.x) # Initialize model
+ grads, loss = self.model.compute_gradients(
+ saved_hidden=saved_hidden, labels=self.t)
+ vars_ = self.model.trainable_variables
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -107,7 +110,7 @@ class RevNetTest(tf.test.TestCase):
# Compare against the true gradient computed by the tape
with tf.GradientTape() as tape:
- logits, _ = self.model(self.x, training=True)
+ logits, _ = self.model(self.x)
loss_true = self.model.compute_loss(logits=logits, labels=self.t)
grads_true = tape.gradient(loss_true, vars_)
self.assertAllClose(loss, loss_true)
@@ -122,7 +125,9 @@ class RevNetTest(tf.test.TestCase):
def test_compute_gradients_defun(self):
"""Test `compute_gradients` function with defun."""
compute_gradients = tfe.defun(self.model.compute_gradients)
- grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True)
+ _, saved_hidden = self.model(self.x)
+ grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t)
+ vars_ = self.model.trainable_variables
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -146,10 +151,11 @@ class RevNetTest(tf.test.TestCase):
dtype=tf.int32)
global_step = tf.Variable(0., trainable=False)
model = revnet.RevNet(config=config)
- grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
+ _, saved_hidden = model(x)
+ grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = optimizer.apply_gradients(
- zip(grads_all, vars_all), global_step=global_step)
+ zip(grads, model.trainable_variables), global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
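
`test_compute_gradients_defun` exercises exactly what the commit title promises: the whole
gradient computation can now be staged with `tfe.defun`. A minimal sketch, assuming an
already-initialized `model`, eager inputs `x` and `t`, an optimizer built elsewhere, and
`tfe = tf.contrib.eager` as at the top of this test file:

    compute_gradients = tfe.defun(model.compute_gradients)
    _, saved_hidden = model(x)
    grads, loss = compute_gradients(saved_hidden=saved_hidden, labels=t)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))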