path: root/tensorflow/contrib/eager/python/examples/revnet/revnet.py
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/revnet.py')
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet.py  149
1 file changed, 66 insertions(+), 83 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index 0228bff6fa..b1cb312b74 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -24,9 +24,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import operator
-
import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -45,66 +42,9 @@ class RevNet(tf.keras.Model):
self.axis = 1 if config.data_format == "channels_first" else 3
self.config = config
- self._init_block = self._construct_init_block()
+ self._init_block = blocks.InitBlock(config=self.config)
+ self._final_block = blocks.FinalBlock(config=self.config)
self._block_list = self._construct_intermediate_blocks()
- self._final_block = self._construct_final_block()
- self._moving_stats_vars = None
-
- def _construct_init_block(self):
- init_block = tf.keras.Sequential(
- [
- tf.keras.layers.Conv2D(
- filters=self.config.init_filters,
- kernel_size=self.config.init_kernel,
- strides=(self.config.init_stride, self.config.init_stride),
- data_format=self.config.data_format,
- use_bias=False,
- padding="SAME",
- input_shape=self.config.input_shape),
- tf.keras.layers.BatchNormalization(
- axis=self.axis, fused=self.config.fused),
- tf.keras.layers.Activation("relu"),
- ],
- name="init")
- if self.config.init_max_pool:
- init_block.add(
- tf.keras.layers.MaxPooling2D(
- pool_size=(3, 3),
- strides=(2, 2),
- padding="SAME",
- data_format=self.config.data_format))
- return init_block
-
- def _construct_final_block(self):
- f = self.config.filters[-1] # Number of filters
- r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio
- r *= self.config.init_stride
- if self.config.init_max_pool:
- r *= 2
-
- if self.config.data_format == "channels_first":
- w, h = self.config.input_shape[1], self.config.input_shape[2]
- input_shape = (f, w // r, h // r)
- elif self.config.data_format == "channels_last":
- w, h = self.config.input_shape[0], self.config.input_shape[1]
- input_shape = (w // r, h // r, f)
- else:
- raise ValueError("Data format should be either `channels_first`"
- " or `channels_last`")
-
- final_block = tf.keras.Sequential(
- [
- tf.keras.layers.BatchNormalization(
- axis=self.axis,
- input_shape=input_shape,
- fused=self.config.fused),
- tf.keras.layers.Activation("relu"),
- tf.keras.layers.GlobalAveragePooling2D(
- data_format=self.config.data_format),
- tf.keras.layers.Dense(self.config.n_classes)
- ],
- name="final")
- return final_block
def _construct_intermediate_blocks(self):
# Precompute input shape after initial block
@@ -139,7 +79,8 @@ class RevNet(tf.keras.Model):
batch_norm_first=(i != 0), # Only skip on first block
data_format=self.config.data_format,
bottleneck=self.config.bottleneck,
- fused=self.config.fused)
+ fused=self.config.fused,
+ dtype=self.config.dtype)
block_list.append(rev_block)
# Precompute input shape for the next block
@@ -174,30 +115,46 @@ class RevNet(tf.keras.Model):
def compute_loss(self, logits, labels):
"""Compute cross entropy loss."""
- cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
- logits=logits, labels=labels)
+ if self.config.dtype == tf.float32 or self.config.dtype == tf.float16:
+ cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
+ else:
+ # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel
+ # for float64, int32 pairs
+ labels = tf.one_hot(
+ labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype)
+ cross_ent = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
return tf.reduce_mean(cross_ent)
- def compute_gradients(self, inputs, labels, training=True):
+ def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
"""Manually computes gradients.
- This method also SILENTLY updates the running averages of batch
- normalization when `training` is set to True.
+ When eager execution is enabled, this method also SILENTLY updates the
+ running averages of batch normalization when `training` is set to True.
Args:
inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
labels: One-hot labels for classification
training: Use the mini-batch stats in batch norm if set to True
+ l2_reg: Apply l2 regularization
Returns:
- list of tuples each being (grad, var) for optimizer to use
+ A tuple with the first entry being a list of all gradients, the second
+ entry being a list of respective variables, the third being the logits,
and the fourth being the loss
"""
- # Run forward pass to record hidden states; avoid updating running averages
+ # Run forward pass to record hidden states
vars_and_vals = self.get_moving_stats()
- _, saved_hidden = self.call(inputs, training=training)
- self.restore_moving_stats(vars_and_vals)
+ _, saved_hidden = self(inputs, training=training) # pylint:disable=not-callable
+ if tf.executing_eagerly():
+ # Restore moving averages when executing eagerly to avoid updating twice
+ self.restore_moving_stats(vars_and_vals)
+ else:
+ # Fetch batch norm updates in graph mode
+ updates = self.get_updates_for(inputs)
grads_all = []
vars_all = []
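
A minimal sketch of how a training step might consume the four values now returned by compute_gradients; the `model`, `optimizer`, `images`, and `labels` names are assumptions for illustration, not part of this change:

    # `model` is a RevNet instance; `optimizer` could be e.g. tf.train.MomentumOptimizer.
    grads, vars_, logits, loss = model.compute_gradients(
        images, labels, training=True, l2_reg=True)
    optimizer.apply_gradients(zip(grads, vars_))
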
@@ -205,9 +162,8 @@ class RevNet(tf.keras.Model):
# Manually backprop through last block
x = saved_hidden[-1]
with tf.GradientTape() as tape:
- x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
tape.watch(x)
- # Running stats updated below
+ # Running stats updated here
logits = self._final_block(x, training=training)
loss = self.compute_loss(logits, labels)
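
A self-contained sketch of the tape pattern in the block above: a plain tensor has to be watched explicitly so its gradient can be recovered alongside the variable gradients from a single tape. The toy layer and shapes are illustrative only.

    import tensorflow as tf
    tf.enable_eager_execution()

    dense = tf.keras.layers.Dense(10)
    x = tf.random_normal([4, 32])
    with tf.GradientTape() as tape:
      tape.watch(x)  # x is not a variable, so it must be watched explicitly
      logits = dense(x)
      loss = tf.reduce_mean(tf.square(logits))
    grads = tape.gradient(loss, [x] + dense.trainable_variables)
    dx, dvars = grads[0], grads[1:]  # gradient w.r.t. the input, then the layer's variables
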
@@ -221,6 +177,7 @@ class RevNet(tf.keras.Model):
for block in reversed(self._block_list):
y = saved_hidden.pop()
x = saved_hidden[-1]
+ # Running stats updated here
dy, grads, vars_ = block.backward_grads_and_vars(
x, y, dy, training=training)
grads_all += grads
@@ -232,18 +189,24 @@ class RevNet(tf.keras.Model):
assert not saved_hidden # Cleared after backprop
with tf.GradientTape() as tape:
- x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
- # Running stats updated below
+ # Running stats updated here
y = self._init_block(x, training=training)
grads_all += tape.gradient(
- y, self._init_block.trainable_variables, output_gradients=[dy])
+ y, self._init_block.trainable_variables, output_gradients=dy)
vars_all += self._init_block.trainable_variables
# Apply weight decay
- grads_all = self._apply_weight_decay(grads_all, vars_all)
+ if l2_reg:
+ grads_all = self._apply_weight_decay(grads_all, vars_all)
- return grads_all, vars_all, loss
+ if not tf.executing_eagerly():
+ # Force updates to be executed before gradient computation in graph mode
+ # This does nothing when the function is wrapped in defun
+ with tf.control_dependencies(updates):
+ grads_all[0] = tf.identity(grads_all[0])
+
+ return grads_all, vars_all, logits, loss
def _apply_weight_decay(self, grads, vars_):
"""Update gradients to reflect weight decay."""
@@ -254,17 +217,37 @@ class RevNet(tf.keras.Model):
]
def get_moving_stats(self):
+ """Get moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Returns:
+ A dictionary mapping variables for batch normalization moving averages
+ to their current values.
+ """
vars_and_vals = {}
def _is_moving_var(v):
n = v.name
return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
- for v in filter(_is_moving_var, self.variables):
- vars_and_vals[v] = v.read_value()
+ device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+ with tf.device(device):
+ for v in filter(_is_moving_var, self.variables):
+ vars_and_vals[v] = v.read_value()
return vars_and_vals
def restore_moving_stats(self, vars_and_vals):
- for var_, val in six.iteritems(vars_and_vals):
- var_.assign(val)
+ """Restore moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Args:
+ vars_and_vals: The dictionary mapping variables to their previous values.
+ """
+ device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
+ with tf.device(device):
+ for var_, val in six.iteritems(vars_and_vals):
+ # `assign` causes a copy to GPU (if variable is already on GPU)
+ var_.assign(val)
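
A minimal sketch of the save/restore pattern that get_moving_stats and restore_moving_stats implement above: snapshot the batch-norm moving statistics before a forward pass that would update them, then assign the snapshot back. The standalone layer here is an assumption for illustration.

    import tensorflow as tf
    tf.enable_eager_execution()

    bn = tf.keras.layers.BatchNormalization()
    x = tf.random_normal([8, 4])
    bn(x, training=True)  # build the layer; this already updates the moving stats once

    def _is_moving_var(v):
      n = v.name
      return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")

    snapshot = {v: v.read_value() for v in filter(_is_moving_var, bn.variables)}
    bn(x, training=True)  # this forward pass updates the moving averages again
    for v, val in snapshot.items():
      v.assign(val)  # restore the snapshot, undoing the extra update
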