author | 2018-07-26 11:29:10 -0700
---|---
committer | 2018-07-26 11:37:33 -0700
commit | e91c597b9aec76bb953567d84e6b92c3b2f5df8f (patch)
tree | 93ac8c5443e728964c6be99f862d5dce3d81d593
parent | a8218323db98a504fe359568c97d0c7e1b978c47 (diff)
Make model totally defunable.
PiperOrigin-RevId: 206192038
7 files changed, 105 insertions, 138 deletions
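The change reorganizes the example so that the forward pass, gradient computation, batch-norm bookkeeping, and the optimizer update are each plain functions of tensors, which is what lets every step be wrapped in `tfe.defun`. A minimal sketch of the resulting eager training step, mirroring the updated `main.py` below (it assumes a constructed RevNet `model` from this example, an `optimizer`, and an `(inputs, labels)` batch; `apply_gradients` is the module-level helper added in this change):

```python
import tensorflow as tf

tfe = tf.contrib.eager


def apply_gradients(optimizer, grads, vars_, global_step=None):
  """Functional-style apply_gradients so it can be wrapped in `tfe.defun`."""
  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)


def train_one_iter(model, inputs, labels, optimizer, global_step=None):
  """One training step using the defunable RevNet API from this commit."""
  # Forward pass records the hidden states needed for reversible backprop.
  logits, saved_hidden = model(inputs, training=True)
  # Snapshot batch-norm moving averages; compute_gradients updates them again.
  values = model.get_moving_stats()
  grads, loss = model.compute_gradients(saved_hidden, labels)
  # Restore the averages so they are only updated once per iteration.
  model.restore_moving_stats(values)
  apply_gradients(
      optimizer, grads, model.trainable_variables, global_step=global_step)
  return logits, loss
```

Each of `model.call`, `model.compute_gradients`, `model.get_moving_stats`, `model.restore_moving_stats`, and `apply_gradients` can then be wrapped with `tfe.defun`, as `main.py` now does behind `--use_defun`.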
```diff
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index 8a530b0d71..2cb04ed258 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -91,12 +91,10 @@ class RevBlock(tf.keras.Model):
       h = block(h, training=training)
     return h
 
-  def backward_grads_and_vars(self, x, y, dy, training=True):
+  def backward_grads(self, x, y, dy, training=True):
     """Apply reversible block backward to outputs."""
 
     grads_all = []
-    vars_all = []
-
     for i in reversed(range(len(self.blocks))):
       block = self.blocks[i]
       if i == 0:
@@ -104,19 +102,15 @@ class RevBlock(tf.keras.Model):
         with tf.GradientTape() as tape:
           tape.watch(x)
           y = block(x, training=training)
-
         grads_combined = tape.gradient(
             y, [x] + block.trainable_variables, output_gradients=dy)
         dy = grads_combined[0]
-        grads_all += grads_combined[1:]
-        vars_all += block.trainable_variables
+        grads_all = grads_combined[1:] + grads_all
       else:
-        y, dy, grads, vars_ = block.backward_grads_and_vars(
-            y, dy, training=training)
-        grads_all += grads
-        vars_all += vars_
+        y, dy, grads = block.backward_grads(y, dy, training=training)
+        grads_all = grads + grads_all
 
-    return dy, grads_all, vars_all
+    return dy, grads_all
 
 
 class _Residual(tf.keras.Model):
@@ -195,7 +189,7 @@ class _Residual(tf.keras.Model):
 
     return tf.concat([y1, y2], axis=self.axis)
 
-  def backward_grads_and_vars(self, y, dy, training=True):
+  def backward_grads(self, y, dy, training=True):
     """Manually compute backward gradients given input and output grads."""
     dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
 
@@ -219,13 +213,11 @@ class _Residual(tf.keras.Model):
 
     del tape
 
-    grads = df + dg
-    vars_ = self.f.trainable_variables + self.g.trainable_variables
-
     x = tf.concat([x1, x2], axis=self.axis)
     dx = tf.concat([dx1, dx2], axis=self.axis)
+    grads = df + dg
 
-    return x, dx, grads, vars_
+    return x, dx, grads
 
 
 # Ideally, the following should be wrapped in `tf.keras.Sequential`, however
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index d74785c8fe..3c6ea63e48 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -179,7 +179,7 @@ class RevBlockTest(tf.test.TestCase):
     degree = compute_degree(g1, g2)
     self.assertLessEqual(degree, atol)
 
-  def test_backward_grads_and_vars_channels_first(self):
+  def test_backward_grads_channels_first(self):
     """Test `backward` function with `channels_first` data format."""
     if not tf.test.is_gpu_available():
       self.skipTest("GPU not available")
@@ -201,7 +201,8 @@ class RevBlockTest(tf.test.TestCase):
       tape.watch(x)
       y = block(x, training=True)
     # Compute grads from reconstruction
-    dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+    dx, dw = block.backward_grads(x, y, dy, training=True)
+    vars_ = block.trainable_variables
     # Compute true grads
     grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
     dx_true, dw_true = grads[0], grads[1:]
@@ -224,7 +225,8 @@ class RevBlockTest(tf.test.TestCase):
       tape.watch(x)
       y = block(x, training=True)
     # Compute grads from reconstruction
-    dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+    dx, dw = block.backward_grads(x, y, dy, training=True)
+    vars_ = block.trainable_variables
     # Compute true grads
     grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
     dx_true, dw_true = grads[0], grads[1:]
@@ -245,7 +247,7 @@ class _ResidualTest(tf.test.TestCase):
     _validate_block_call_channels_first(blocks._Residual, self)
     _validate_block_call_channels_last(blocks._Residual, self)
 
-  def test_backward_grads_and_vars_channels_first(self):
+  def test_backward_grads_channels_first(self):
     """Test `backward_grads` function with `channels_first` data format."""
     if not tf.test.is_gpu_available():
       self.skipTest("GPU not available")
@@ -269,9 +271,8 @@ class _ResidualTest(tf.test.TestCase):
       y = residual(x_true, training=True)
 
     # Gradients computed due to reversibility
-    x, dx, dw, vars_ = residual.backward_grads_and_vars(
-        y, dy=dy, training=True)
-
+    x, dx, dw = residual.backward_grads(y, dy=dy, training=True)
+    vars_ = residual.trainable_variables
    # True gradients computed by the tape
     grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy)
     dx_true, dw_true = grads[0], grads[1:]
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index dcd4e1697f..b702e91f92 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.eager.python.examples.revnet import revnet
 tfe = tf.contrib.eager
 
 
+def apply_gradients(optimizer, grads, vars_, global_step=None):
+  """Functional style apply_grads for `tfe.defun`."""
+  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+
+
 def main(_):
   """Eager execution workflow with RevNet trained on CIFAR-10."""
   tf.enable_eager_execution()
@@ -48,6 +53,11 @@ def main(_):
 
   if FLAGS.use_defun:
     model.call = tfe.defun(model.call)
+    model.compute_gradients = tfe.defun(model.compute_gradients)
+    model.get_moving_stats = tfe.defun(model.get_moving_stats)
+    model.restore_moving_stats = tfe.defun(model.restore_moving_stats)
+    global apply_gradients  # pylint:disable=global-variable-undefined
+    apply_gradients = tfe.defun(apply_gradients)
 
   if FLAGS.train_dir:
     summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
@@ -197,9 +207,13 @@ def get_datasets(data_dir, config):
 
 def train_one_iter(model, inputs, labels, optimizer, global_step=None):
   """Train for one iteration."""
-  grads, vars_, logits, loss = model.compute_gradients(
-      inputs, labels, training=True)
-  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+  logits, saved_hiddens = model(inputs, training=True)
+  values = model.get_moving_stats()
+  grads, loss = model.compute_gradients(saved_hiddens, labels)
+  # Restore moving averages when executing eagerly to avoid updating twice
+  model.restore_moving_stats(values)
+  apply_gradients(
+      optimizer, grads, model.trainable_variables, global_step=global_step)
 
   return logits, loss
 
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
index 4868f1931f..df25b5066f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py
@@ -53,10 +53,10 @@ def model_fn(features, labels, mode, params):
         global_step, config.lr_decay_steps, config.lr_list)
     optimizer = tf.train.MomentumOptimizer(
         learning_rate, momentum=config.momentum)
-    grads, vars_, logits, loss = model.compute_gradients(
-        inputs, labels, training=True)
+    logits, saved_hidden = model(inputs, training=True)
+    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
     train_op = optimizer.apply_gradients(
-        zip(grads, vars_), global_step=global_step)
+        zip(grads, model.trainable_variables), global_step=global_step)
 
     return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
   else:
@@ -130,8 +130,7 @@ def get_input_fn(config, data_dir, split):
   return input_fn
 
 
-def main(argv):
-  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
 
   # RevNet specific configuration
@@ -197,4 +196,4 @@ if __name__ == "__main__":
       help="[Optional] Architecture of network. "
      "Other options include `revnet-110` and `revnet-164`")
   FLAGS = flags.FLAGS
-  tf.app.run(main=main, argv=[FLAGS])
+  tf.app.run()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
index d809bcd287..f0aad9b110 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py
@@ -47,7 +47,6 @@ def model_fn(features, labels, mode, params):
   if isinstance(inputs, dict):
     inputs = features["image"]
 
-  FLAGS = params["FLAGS"]  # pylint:disable=invalid-name,redefined-outer-name
   config = params["config"]
   model = revnet.RevNet(config=config)
 
@@ -61,14 +60,10 @@ def model_fn(features, labels, mode, params):
     if FLAGS.use_tpu:
       optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
 
-    # Define gradients
-    grads, vars_, logits, loss = model.compute_gradients(
-        inputs, labels, training=True)
+    logits, saved_hidden = model(inputs, training=True)
+    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
     train_op = optimizer.apply_gradients(
-        zip(grads, vars_), global_step=global_step)
-
-    names = [v.name for v in model.variables]
-    tf.logging.warn("{}".format(names))
+        zip(grads, model.trainable_variables), global_step=global_step)
 
     return tf.contrib.tpu.TPUEstimatorSpec(
         mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)
@@ -141,8 +136,7 @@ def get_input_fn(config, data_dir, split):
   return input_fn
 
 
-def main(argv):
-  FLAGS = argv[0]  # pylint:disable=invalid-name,redefined-outer-name
+def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
 
   # RevNet specific configuration
@@ -177,10 +171,7 @@ def main(argv):
       train_batch_size=config.tpu_batch_size,
       eval_batch_size=config.tpu_eval_batch_size,
       config=run_config,
-      params={
-          "FLAGS": FLAGS,
-          "config": config,
-      })
+      params={"config": config})
 
   # Construct input functions
   train_input_fn = get_input_fn(
@@ -325,4 +316,4 @@ if __name__ == "__main__":
           " possible (i.e. up to --train_steps, which evaluates the model only"
           " after finishing the entire training regime)."))
   FLAGS = flags.FLAGS
-  tf.app.run(main=main, argv=[FLAGS])
+  tf.app.run()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index b1cb312b74..1f2cb14972 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,7 +24,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import six
 import tensorflow as tf
 
 from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -45,6 +44,7 @@ class RevNet(tf.keras.Model):
     self._init_block = blocks.InitBlock(config=self.config)
     self._final_block = blocks.FinalBlock(config=self.config)
     self._block_list = self._construct_intermediate_blocks()
+    self._moving_average_variables = []
 
   def _construct_intermediate_blocks(self):
     # Precompute input shape after initial block
@@ -128,126 +128,90 @@ class RevNet(tf.keras.Model):
 
     return tf.reduce_mean(cross_ent)
 
-  def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
+  def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True):
     """Manually computes gradients.
 
-    When eager execution is enabled, this method also SILENTLY updates the
-    running averages of batch normalization when `training` is set to True.
+    This method silently updates the running averages of batch normalization.
 
     Args:
-      inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      saved_hidden: List of hidden states Tensors
      labels: One-hot labels for classification
       training: Use the mini-batch stats in batch norm if set to True
       l2_reg: Apply l2 regularization
 
     Returns:
-      A tuple with the first entry being a list of all gradients, the second
-      entry being a list of respective variables, the third being the logits,
-      and the forth being the loss
+      A tuple with the first entry being a list of all gradients and the second
+      being the loss
     """
-    # Run forward pass to record hidden states
-    vars_and_vals = self.get_moving_stats()
-    _, saved_hidden = self(inputs, training=training)  # pylint:disable=not-callable
-    if tf.executing_eagerly():
-      # Restore moving averages when executing eagerly to avoid updating twice
-      self.restore_moving_stats(vars_and_vals)
-    else:
-      # Fetch batch norm updates in graph mode
-      updates = self.get_updates_for(inputs)
-
-    grads_all = []
-    vars_all = []
 
+    def _defunable_pop(l):
+      """Functional style list pop that works with `tfe.defun`."""
+      t, l = l[-1], l[:-1]
+      return t, l
 
-    # Manually backprop through last block
+    # Backprop through last block
     x = saved_hidden[-1]
     with tf.GradientTape() as tape:
       tape.watch(x)
-      # Running stats updated here
       logits = self._final_block(x, training=training)
       loss = self.compute_loss(logits, labels)
-
     grads_combined = tape.gradient(loss,
                                    [x] + self._final_block.trainable_variables)
-    dy, grads_ = grads_combined[0], grads_combined[1:]
-    grads_all += grads_
-    vars_all += self._final_block.trainable_variables
+    dy, final_grads = grads_combined[0], grads_combined[1:]
 
-    # Manually backprop through intermediate blocks
+    # Backprop through intermediate blocks
+    intermediate_grads = []
     for block in reversed(self._block_list):
-      y = saved_hidden.pop()
+      y, saved_hidden = _defunable_pop(saved_hidden)
       x = saved_hidden[-1]
-      # Running stats updated here
-      dy, grads, vars_ = block.backward_grads_and_vars(
-          x, y, dy, training=training)
-      grads_all += grads
-      vars_all += vars_
-
-    # Manually backprop through first block
-    saved_hidden.pop()
-    x = saved_hidden.pop()
-    assert not saved_hidden  # Cleared after backprop
+      dy, grads = block.backward_grads(x, y, dy, training=training)
+      intermediate_grads = grads + intermediate_grads
+
+    # Backprop through first block
+    _, saved_hidden = _defunable_pop(saved_hidden)
+    x, saved_hidden = _defunable_pop(saved_hidden)
+    assert not saved_hidden
     with tf.GradientTape() as tape:
-      # Running stats updated here
       y = self._init_block(x, training=training)
-
-    grads_all += tape.gradient(
+    init_grads = tape.gradient(
         y, self._init_block.trainable_variables, output_gradients=dy)
-    vars_all += self._init_block.trainable_variables
 
-    # Apply weight decay
+    # Ordering match up with `model.trainable_variables`
+    grads_all = init_grads + final_grads + intermediate_grads
     if l2_reg:
-      grads_all = self._apply_weight_decay(grads_all, vars_all)
-
-    if not tf.executing_eagerly():
-      # Force updates to be executed before gradient computation in graph mode
-      # This does nothing when the function is wrapped in defun
-      with tf.control_dependencies(updates):
-        grads_all[0] = tf.identity(grads_all[0])
+      grads_all = self._apply_weight_decay(grads_all)
 
-    return grads_all, vars_all, logits, loss
+    return grads_all, loss
 
-  def _apply_weight_decay(self, grads, vars_):
+  def _apply_weight_decay(self, grads):
     """Update gradients to reflect weight decay."""
-    # Don't decay bias
     return [
         g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
-        for g, v in zip(grads, vars_)
+        for g, v in zip(grads, self.trainable_variables)
     ]
 
   def get_moving_stats(self):
-    """Get moving averages of batch normalization.
-
-    This is needed to avoid updating the running average twice in one iteration.
-
-    Returns:
-      A dictionary mapping variables for batch normalization moving averages
-      to their current values.
-    """
-    vars_and_vals = {}
-
-    def _is_moving_var(v):
-      n = v.name
-      return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+    """Get moving averages of batch normalization."""
+    device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+    with tf.device(device):
+      return [v.read_value() for v in self.moving_average_variables]
 
+  def restore_moving_stats(self, values):
+    """Restore moving averages of batch normalization."""
     device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
     with tf.device(device):
-      for v in filter(_is_moving_var, self.variables):
-        vars_and_vals[v] = v.read_value()
+      for var_, val in zip(self.moving_average_variables, values):
+        var_.assign(val)
 
-    return vars_and_vals
+  @property
+  def moving_average_variables(self):
+    """Get all variables that are batch norm moving averages."""
 
-  def restore_moving_stats(self, vars_and_vals):
-    """Restore moving averages of batch normalization.
+    def _is_moving_avg(v):
+      n = v.name
+      return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
 
-    This is needed to avoid updating the running average twice in one iteration.
+    if not self._moving_average_variables:
+      self._moving_average_variables = filter(_is_moving_avg, self.variables)
 
-    Args:
-      vars_and_vals: The dictionary mapping variables to their previous values.
-    """
-    device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
-    with tf.device(device):
-      for var_, val in six.iteritems(vars_and_vals):
-        # `assign` causes a copy to GPU (if variable is already on GPU)
-        var_.assign(val)
+    return self._moving_average_variables
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 26b0847523..84b2ddf0de 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -31,9 +31,11 @@ tfe = tf.contrib.eager
 
 def train_one_iter(model, inputs, labels, optimizer, global_step=None):
   """Train for one iteration."""
-  grads, vars_, logits, loss = model.compute_gradients(
-      inputs, labels, training=True)
-  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+  logits, saved_hidden = model(inputs)
+  grads, loss = model.compute_gradients(
+      saved_hidden=saved_hidden, labels=labels)
+  optimizer.apply_gradients(
+      zip(grads, model.trainable_variables), global_step=global_step)
 
   return logits, loss
 
@@ -96,9 +98,10 @@ class RevNetTest(tf.test.TestCase):
 
   def test_compute_gradients(self):
     """Test `compute_gradients` function."""
-    self.model(self.x, training=False)  # Initialize model
-    grads, vars_, logits, loss = self.model.compute_gradients(
-        inputs=self.x, labels=self.t, training=True, l2_reg=True)
+    _, saved_hidden = self.model(self.x)  # Initialize model
+    grads, loss = self.model.compute_gradients(
+        saved_hidden=saved_hidden, labels=self.t)
+    vars_ = self.model.trainable_variables
     self.assertTrue(isinstance(grads, list))
     self.assertTrue(isinstance(vars_, list))
     self.assertEqual(len(grads), len(vars_))
@@ -107,7 +110,7 @@ class RevNetTest(tf.test.TestCase):
 
     # Compare against the true gradient computed by the tape
     with tf.GradientTape() as tape:
-      logits, _ = self.model(self.x, training=True)
+      logits, _ = self.model(self.x)
       loss_true = self.model.compute_loss(logits=logits, labels=self.t)
     grads_true = tape.gradient(loss_true, vars_)
     self.assertAllClose(loss, loss_true)
@@ -122,7 +125,9 @@ class RevNetTest(tf.test.TestCase):
   def test_compute_gradients_defun(self):
     """Test `compute_gradients` function with defun."""
     compute_gradients = tfe.defun(self.model.compute_gradients)
-    grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True)
+    _, saved_hidden = self.model(self.x)
+    grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t)
+    vars_ = self.model.trainable_variables
     self.assertTrue(isinstance(grads, list))
     self.assertTrue(isinstance(vars_, list))
     self.assertEqual(len(grads), len(vars_))
@@ -146,10 +151,11 @@ class RevNetTest(tf.test.TestCase):
           dtype=tf.int32)
       global_step = tf.Variable(0., trainable=False)
       model = revnet.RevNet(config=config)
-      grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
+      _, saved_hidden = model(x)
+      grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t)
       optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
       train_op = optimizer.apply_gradients(
-          zip(grads_all, vars_all), global_step=global_step)
+          zip(grads, model.trainable_variables), global_step=global_step)
 
       with tf.Session() as sess:
         sess.run(tf.global_variables_initializer())
```