Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/revnet.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/revnet.py | 149 |
1 file changed, 66 insertions, 83 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index 0228bff6fa..b1cb312b74 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,9 +24,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
-import operator
-
 import six
 import tensorflow as tf
 from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -45,66 +42,9 @@ class RevNet(tf.keras.Model):
     self.axis = 1 if config.data_format == "channels_first" else 3
     self.config = config
 
-    self._init_block = self._construct_init_block()
+    self._init_block = blocks.InitBlock(config=self.config)
+    self._final_block = blocks.FinalBlock(config=self.config)
     self._block_list = self._construct_intermediate_blocks()
-    self._final_block = self._construct_final_block()
-    self._moving_stats_vars = None
-
-  def _construct_init_block(self):
-    init_block = tf.keras.Sequential(
-        [
-            tf.keras.layers.Conv2D(
-                filters=self.config.init_filters,
-                kernel_size=self.config.init_kernel,
-                strides=(self.config.init_stride, self.config.init_stride),
-                data_format=self.config.data_format,
-                use_bias=False,
-                padding="SAME",
-                input_shape=self.config.input_shape),
-            tf.keras.layers.BatchNormalization(
-                axis=self.axis, fused=self.config.fused),
-            tf.keras.layers.Activation("relu"),
-        ],
-        name="init")
-    if self.config.init_max_pool:
-      init_block.add(
-          tf.keras.layers.MaxPooling2D(
-              pool_size=(3, 3),
-              strides=(2, 2),
-              padding="SAME",
-              data_format=self.config.data_format))
-    return init_block
-
-  def _construct_final_block(self):
-    f = self.config.filters[-1]  # Number of filters
-    r = functools.reduce(operator.mul, self.config.strides, 1)  # Reduce ratio
-    r *= self.config.init_stride
-    if self.config.init_max_pool:
-      r *= 2
-
-    if self.config.data_format == "channels_first":
-      w, h = self.config.input_shape[1], self.config.input_shape[2]
-      input_shape = (f, w // r, h // r)
-    elif self.config.data_format == "channels_last":
-      w, h = self.config.input_shape[0], self.config.input_shape[1]
-      input_shape = (w // r, h // r, f)
-    else:
-      raise ValueError("Data format should be either `channels_first`"
-                       " or `channels_last`")
-
-    final_block = tf.keras.Sequential(
-        [
-            tf.keras.layers.BatchNormalization(
-                axis=self.axis,
-                input_shape=input_shape,
-                fused=self.config.fused),
-            tf.keras.layers.Activation("relu"),
-            tf.keras.layers.GlobalAveragePooling2D(
-                data_format=self.config.data_format),
-            tf.keras.layers.Dense(self.config.n_classes)
-        ],
-        name="final")
-    return final_block
 
   def _construct_intermediate_blocks(self):
     # Precompute input shape after initial block
@@ -139,7 +79,8 @@ class RevNet(tf.keras.Model):
           batch_norm_first=(i != 0),  # Only skip on first block
           data_format=self.config.data_format,
           bottleneck=self.config.bottleneck,
-          fused=self.config.fused)
+          fused=self.config.fused,
+          dtype=self.config.dtype)
       block_list.append(rev_block)
 
       # Precompute input shape for the next block
@@ -174,30 +115,46 @@ class RevNet(tf.keras.Model):
   def compute_loss(self, logits, labels):
     """Compute cross entropy loss."""
-    cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
-        logits=logits, labels=labels)
+    if self.config.dtype == tf.float32 or self.config.dtype == tf.float16:
+      cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          logits=logits, labels=labels)
+    else:
+      # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel
+      # for float64, int32 pairs
+      labels = tf.one_hot(
+          labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype)
+      cross_ent = tf.nn.softmax_cross_entropy_with_logits(
+          logits=logits, labels=labels)
 
     return tf.reduce_mean(cross_ent)
 
-  def compute_gradients(self, inputs, labels, training=True):
+  def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
     """Manually computes gradients.
 
-    This method also SILENTLY updates the running averages of batch
-    normalization when `training` is set to True.
+    When eager execution is enabled, this method also SILENTLY updates the
+    running averages of batch normalization when `training` is set to True.
 
     Args:
       inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
       labels: One-hot labels for classification
       training: Use the mini-batch stats in batch norm if set to True
+      l2_reg: Apply l2 regularization
 
     Returns:
-      list of tuples each being (grad, var) for optimizer to use
+      A tuple with the first entry being a list of all gradients, the second
+      entry being a list of respective variables, the third being the logits,
+      and the fourth being the loss
     """
-    # Run forward pass to record hidden states; avoid updating running averages
+    # Run forward pass to record hidden states
     vars_and_vals = self.get_moving_stats()
-    _, saved_hidden = self.call(inputs, training=training)
-    self.restore_moving_stats(vars_and_vals)
+    _, saved_hidden = self(inputs, training=training)  # pylint:disable=not-callable
+    if tf.executing_eagerly():
+      # Restore moving averages when executing eagerly to avoid updating twice
+      self.restore_moving_stats(vars_and_vals)
+    else:
+      # Fetch batch norm updates in graph mode
+      updates = self.get_updates_for(inputs)
 
     grads_all = []
     vars_all = []
@@ -205,9 +162,8 @@ class RevNet(tf.keras.Model):
     # Manually backprop through last block
     x = saved_hidden[-1]
     with tf.GradientTape() as tape:
-      x = tf.identity(x)  # TODO(lxuechen): Remove after b/110264016 is fixed
       tape.watch(x)
-      # Running stats updated below
+      # Running stats updated here
       logits = self._final_block(x, training=training)
       loss = self.compute_loss(logits, labels)
 
@@ -221,6 +177,7 @@ class RevNet(tf.keras.Model):
     for block in reversed(self._block_list):
       y = saved_hidden.pop()
       x = saved_hidden[-1]
+      # Running stats updated here
       dy, grads, vars_ = block.backward_grads_and_vars(
           x, y, dy, training=training)
       grads_all += grads
@@ -232,18 +189,24 @@ class RevNet(tf.keras.Model):
     assert not saved_hidden  # Cleared after backprop
 
     with tf.GradientTape() as tape:
-      x = tf.identity(x)  # TODO(lxuechen): Remove after b/110264016 is fixed
-      # Running stats updated below
+      # Running stats updated here
       y = self._init_block(x, training=training)
 
     grads_all += tape.gradient(
-        y, self._init_block.trainable_variables, output_gradients=[dy])
+        y, self._init_block.trainable_variables, output_gradients=dy)
     vars_all += self._init_block.trainable_variables
 
     # Apply weight decay
-    grads_all = self._apply_weight_decay(grads_all, vars_all)
+    if l2_reg:
+      grads_all = self._apply_weight_decay(grads_all, vars_all)
 
-    return grads_all, vars_all, loss
+    if not tf.executing_eagerly():
+      # Force updates to be executed before gradient computation in graph mode
+      # This does nothing when the function is wrapped in defun
+      with tf.control_dependencies(updates):
+        grads_all[0] = tf.identity(grads_all[0])
+
+    return grads_all, vars_all, logits, loss
 
   def _apply_weight_decay(self, grads, vars_):
     """Update gradients to reflect weight decay."""
@@ -254,17 +217,37 @@ class RevNet(tf.keras.Model):
     ]
 
   def get_moving_stats(self):
+    """Get moving averages of batch normalization.
+
+    This is needed to avoid updating the running average twice in one iteration.
+
+    Returns:
+      A dictionary mapping variables for batch normalization moving averages
+      to their current values.
+    """
     vars_and_vals = {}
 
     def _is_moving_var(v):
       n = v.name
       return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
 
-    for v in filter(_is_moving_var, self.variables):
-      vars_and_vals[v] = v.read_value()
+    device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+    with tf.device(device):
+      for v in filter(_is_moving_var, self.variables):
+        vars_and_vals[v] = v.read_value()
 
     return vars_and_vals
 
   def restore_moving_stats(self, vars_and_vals):
+    """Restore moving averages of batch normalization.
+
+    This is needed to avoid updating the running average twice in one iteration.
+
+    Args:
+      vars_and_vals: The dictionary mapping variables to their previous values.
+    """
+    device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+    with tf.device(device):
+      for var_, val in six.iteritems(vars_and_vals):
+        # `assign` causes a copy to GPU (if variable is already on GPU)
+        var_.assign(val)
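The `compute_loss` change above works around the lack of a float64 GPU kernel for `tf.nn.sparse_softmax_cross_entropy_with_logits` with int32 labels by expanding labels to one-hot and using the dense op. A minimal standalone sketch of that path, assuming a hypothetical 10-class problem and the TF 1.x-style APIs used elsewhere in this example:

import tensorflow as tf
tf.enable_eager_execution()

# Hypothetical batch of 8 examples with float64 logits.
logits = tf.random_normal([8, 10], dtype=tf.float64)
labels = tf.random_uniform([8], maxval=10, dtype=tf.int32)

# Expand sparse int labels to one-hot in the logits' dtype, then use the
# dense cross entropy, which does have a float64 kernel.
one_hot = tf.one_hot(labels, depth=10, axis=1, dtype=tf.float64)
cross_ent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=one_hot)
loss = tf.reduce_mean(cross_ent)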
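Since `compute_gradients` now returns four values and gains an `l2_reg` flag, callers need updating. Below is a sketch of an eager-mode training step under the new signature; the optimizer choice, the `ds_train` iterator, and the `get_hparams_cifar_38` config helper are assumptions for illustration, not part of this change:

import tensorflow as tf
tf.enable_eager_execution()

from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet

config = config_.get_hparams_cifar_38()  # assumed config helper
model = revnet.RevNet(config=config)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

for images, labels in ds_train:  # hypothetical tf.data.Dataset of batches
  # Gradients and variables come back as parallel lists; logits and loss are
  # also returned so callers can track accuracy without a second forward pass.
  grads, vars_, logits, loss = model.compute_gradients(
      images, labels, training=True, l2_reg=True)
  optimizer.apply_gradients(zip(grads, vars_))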