author    Chris Ying <chrisying@google.com>  2017-09-28 11:05:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-28 11:09:20 -0700
commit    125f7afa4a483855dc75791445d2dea64587876a (patch)
tree      bc9f097825e600b1e84b194577711f95ccf584dc
parent    d3d60ff6acec178b1cf912938aa6180bbd1a676f (diff)
Implementing ghost batch norm as defined in https://arxiv.org/pdf/1705.08741.
Reuses most of tf.layers.batch_normalization's existing functionality via some reshaping and transposing tricks. Toggled via the additional optional parameter `num_virtual_batches`. Ghost batch norm is essential for large-batch training, where the true batch size differs from the batch-norm batch size.

PiperOrigin-RevId: 170368495
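For illustration, a minimal usage sketch of the new parameter with the tf.layers API of this era; the input shapes below are hypothetical, and the batch size must be divisible by `num_virtual_batches`:

    # Minimal sketch, assuming the TF 1.x graph-mode API; shapes are hypothetical.
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [8, 32, 32, 3])   # batch of 8 NHWC images
    y = tf.layers.batch_normalization(
        x,
        training=True,
        num_virtual_batches=4)  # normalize over 4 "ghost" sub-batches of size 2
    # y has the same shape as x: [8, 32, 32, 3]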
-rw-r--r--  tensorflow/python/layers/normalization.py                                  92
-rw-r--r--  tensorflow/python/layers/normalization_test.py                            195
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt    2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.layers.pbtxt                         2
4 files changed, 279 insertions(+), 12 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index f9fe7b34bb..bcdb67ae90 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -49,7 +49,7 @@ class BatchNormalization(base.Layer):
Sergey Ioffe, Christian Szegedy
Arguments:
- axis: Integer, the axis that should be normalized (typically the features
+ axis: An `int`, the axis that should be normalized (typically the features
axis). For instance, after a `Conv2D` layer with
`data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
momentum: Momentum for the moving average.
@@ -90,6 +90,11 @@ class BatchNormalization(base.Layer):
If `None`, use the system recommended implementation.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ num_virtual_batches: An `int`, the number of virtual batches to operate
+ over. If greater than 1, performs "ghost batch normalization", which
+ creates virtual sub-batches to normalize over. Defaults to 1, in which
+ case no virtual batching is performed. Must evenly divide the actual
+ batch size during graph execution.
name: A string, the name of the layer.
"""
@@ -112,6 +117,7 @@ class BatchNormalization(base.Layer):
renorm_momentum=0.99,
fused=None,
trainable=True,
+ num_virtual_batches=1,
name=None,
**kwargs):
super(BatchNormalization, self).__init__(
@@ -135,6 +141,11 @@ class BatchNormalization(base.Layer):
self.fused = fused
self._bessels_correction_test_only = True
+
+ if num_virtual_batches < 1:
+ raise ValueError('num_virtual_batches must be a positive integer')
+ self.num_virtual_batches = num_virtual_batches
+
if renorm:
renorm_clipping = renorm_clipping or {}
keys = ['rmax', 'rmin', 'dmax']
@@ -180,6 +191,10 @@ class BatchNormalization(base.Layer):
self.input_spec = base.InputSpec(ndim=ndim,
axes={self.axis: param_dim.value})
+ if self.num_virtual_batches > 1:
+ # The axis dimension is combined with num_virtual_batches.
+ param_dim = input_shape[axis] * self.num_virtual_batches
+
if self.scale:
self.gamma = self.add_variable(name='gamma',
shape=(param_dim,),
@@ -389,8 +404,53 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
+ if self.num_virtual_batches > 1:
+ # Virtual batches (aka ghost batches) can be simulated by using some
+ # reshape/transpose tricks on top of base batch normalization.
+ original_shape = [-1] + inputs.shape.as_list()[1:]
+ expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]
+
+ # Will cause errors if num_virtual_batches does not divide the batch size
+ inputs = array_ops.reshape(inputs, expanded_shape)
+
+ ndims = len(expanded_shape)
+ if self.axis < 0:
+ axis = ndims + self.axis
+ else:
+ axis = self.axis + 1 # Account for the added dimension
+
+ # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
+ # TODO(b/66257056): when multi-axis batch normalization is implemented,
+ # this permutation trick and the combined_dim reshape are no longer
+ # necessary and can be reworked to simply use broadcasting.
+ permutation = ([0] + list(range(2, axis)) + [1, axis] +
+ list(range(axis + 1, ndims)))
+ inverse_permutation = [x[1] for x in
+ sorted(zip(permutation, range(ndims)))]
+ inputs = array_ops.transpose(inputs, perm=permutation)
+
+ # Combine the axis and num_virtual_batch dimension in order to take
+ # advantage of fused batch normalization
+ combined_dim = expanded_shape[1] * expanded_shape[axis]
+ perm_shape = [-1] + inputs.shape.as_list()[1:]
+ combined_shape = (perm_shape[:axis - 1] +
+ [combined_dim] +
+ perm_shape[axis + 1:])
+ inputs = array_ops.reshape(inputs, combined_shape)
+ # After the above reshape, the batch norm axis is the original self.axis
+
+ # Undoes the reshaping and transposing tricks done above
+ def undo_virtual_batching(outputs):
+ outputs = array_ops.reshape(outputs, perm_shape)
+ outputs = array_ops.transpose(outputs, perm=inverse_permutation)
+ outputs = array_ops.reshape(outputs, original_shape)
+ return outputs
+
if self.fused:
- return self._fused_batch_norm(inputs, training=training)
+ outputs = self._fused_batch_norm(inputs, training=training)
+ if self.num_virtual_batches > 1:
+ return undo_virtual_batching(outputs)
+ return outputs
# First, compute the axes along which to reduce the mean / variance,
# as well as the broadcast shape to be used for all parameters.
@@ -454,12 +514,17 @@ class BatchNormalization(base.Layer):
return array_ops.reshape(v, broadcast_shape)
return v
- return nn.batch_normalization(inputs,
- _broadcast(mean),
- _broadcast(variance),
- _broadcast(offset),
- _broadcast(scale),
- self.epsilon)
+ outputs = nn.batch_normalization(inputs,
+ _broadcast(mean),
+ _broadcast(variance),
+ _broadcast(offset),
+ _broadcast(scale),
+ self.epsilon)
+
+ if self.num_virtual_batches > 1:
+ return undo_virtual_batching(outputs)
+
+ return outputs
def batch_normalization(inputs,
@@ -483,7 +548,8 @@ def batch_normalization(inputs,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
- fused=None):
+ fused=None,
+ num_virtual_batches=1):
"""Functional interface for the batch normalization layer.
Reference: http://arxiv.org/abs/1502.03167
@@ -505,7 +571,7 @@ def batch_normalization(inputs,
Arguments:
inputs: Tensor input.
- axis: Integer, the axis that should be normalized (typically the features
+ axis: An `int`, the axis that should be normalized (typically the features
axis). For instance, after a `Convolution2D` layer with
`data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
momentum: Momentum for the moving average.
@@ -555,6 +621,11 @@ def batch_normalization(inputs,
to get the means and variances for inference.
fused: if `True`, use a faster, fused implementation if possible.
If `None`, use the system recommended implementation.
+ num_virtual_batches: An `int`, the number of virtual batches to operate
+ over. If greater than 1, performs "ghost batch normalization", which
+ creates virtual sub-batches to normalize over. Defaults to 1, in which
+ case no virtual batching is performed. Must evenly divide the actual
+ batch size during graph execution.
Returns:
Output tensor.
@@ -578,6 +649,7 @@ def batch_normalization(inputs,
renorm_momentum=renorm_momentum,
fused=fused,
trainable=trainable,
+ num_virtual_batches=num_virtual_batches,
name=name,
_reuse=reuse,
_scope=name)
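For reference, a small NumPy sketch of the reshape/transpose bookkeeping that the new `call` code performs, assuming NHWC inputs with `axis=-1` and two virtual batches; the array shapes and names here are illustrative, not commit code:

    # Illustrative NumPy sketch of the virtual-batching reshapes; not commit code.
    import numpy as np

    x = np.random.rand(6, 4, 4, 3).astype(np.float32)   # [N, H, W, C]
    num_virtual_batches = 2

    # Split the batch into virtual sub-batches: [N/V, V, H, W, C].
    expanded = x.reshape([-1, num_virtual_batches] + list(x.shape[1:]))

    # Move the virtual-batch dim next to the channel axis and merge the two,
    # so fused batch norm sees V*C independent "channels": [N/V, H, W, V*C].
    permutation = [0, 2, 3, 1, 4]
    inverse_permutation = list(np.argsort(permutation))   # [0, 3, 1, 2, 4]
    combined = expanded.transpose(permutation).reshape(
        [-1, x.shape[1], x.shape[2], num_virtual_batches * x.shape[3]])

    # ... batch normalization would run here over the merged channel axis ...

    # Undo the merge, permutation, and split to recover the original layout.
    restored = (combined
                .reshape([-1, x.shape[1], x.shape[2],
                          num_virtual_batches, x.shape[3]])
                .transpose(inverse_permutation)
                .reshape(x.shape))
    assert restored.shape == x.shape and np.allclose(restored, x)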
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 3dc6a33b44..ccb0662c4e 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -823,6 +823,201 @@ class BNTest(test.TestCase):
self.assertAllClose(y_train, yt_val_train, atol=1e-5)
self.assertAllClose(y_test, yt_val_test, atol=1e-5)
+ def testGhostBNVirtualBatch1(self):
+ shape = [6, 5, 4, 3]
+ inp = random_ops.random_uniform(shape, seed=1)
+ out1 = normalization_layers.batch_normalization(inp)
+ out2 = normalization_layers.batch_normalization(
+ inp, num_virtual_batches=1)
+
+ self.assertListEqual(
+ out1.shape.as_list(), out2.shape.as_list())
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+
+ x = np.random.random(shape)
+ y1, y2 = sess.run([out1, out2], feed_dict={inp: x})
+
+ self.assertAllClose(y1, y2, atol=1e-5)
+
+ def testGhostBNNegativeVirtualBatch(self):
+ shape = [6, 5, 4, 3]
+ inp = random_ops.random_uniform(shape, seed=1)
+
+ with self.assertRaises(ValueError):
+ normalization_layers.batch_normalization(
+ inp, num_virtual_batches=-1)
+
+ def testGhostBNInputOutputShapesMatch(self):
+ shape = [6, 4, 3]
+ inp = random_ops.random_uniform(shape, seed=1)
+ out = normalization_layers.batch_normalization(
+ inp, num_virtual_batches=2)
+ self.assertListEqual(out.shape.as_list(), shape)
+
+ def testGhostBNUnknownBatchSize(self):
+ np_shape = [10, 5, 4]
+ tf_shape = [None, 5, 4]
+ inp = array_ops.placeholder(dtypes.float32, tf_shape)
+ out = normalization_layers.batch_normalization(
+ inp, num_virtual_batches=5)
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+
+ x = np.random.random(np_shape)
+ y = sess.run(out, feed_dict={inp: x})
+
+ self.assertListEqual(list(y.shape), np_shape)
+
+ def testGhostBN2Dims(self):
+ shape = [6, 2]
+ num_virtual_batches = 2
+ beta = 2.
+ gamma = 3.
+ momentum = 0.8
+ epsilon = 1e-3
+ moving_means = np.zeros([2, 2], dtype=np.float32)
+ moving_vars = np.ones([2, 2], dtype=np.float32)
+
+ inp = array_ops.placeholder(dtypes.float32, shape)
+ is_training = array_ops.placeholder(dtypes.bool)
+ bn = normalization_layers.BatchNormalization(
+ momentum=momentum,
+ epsilon=epsilon,
+ beta_initializer=init_ops.constant_initializer(beta),
+ gamma_initializer=init_ops.constant_initializer(gamma),
+ num_virtual_batches=num_virtual_batches)
+ out = bn.apply(inp, training=is_training)
+ ghost_shape = ([shape[0] // num_virtual_batches,
+ num_virtual_batches, shape[1]])
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ for _ in range(5):
+ x = np.random.random(shape)
+
+ sub_batched = np.reshape(x, ghost_shape)
+ means = np.mean(sub_batched, axis=0)
+ variances = np.var(sub_batched, axis=0)
+ moving_means = moving_means * momentum + means * (1. - momentum)
+ moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+ y_train = ((sub_batched - means) /
+ (variances + epsilon) ** 0.5 * gamma) + beta
+ y_test = ((sub_batched - moving_means) /
+ (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+ y_train = np.reshape(y_train, shape)
+ y_test = np.reshape(y_test, shape)
+
+ y_val_train, _, _ = sess.run([out] + bn.updates,
+ feed_dict={inp: x, is_training: True})
+ y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+ self.assertAllClose(y_train, y_val_train, atol=1e-5)
+ self.assertAllClose(y_test, y_val_test, atol=1e-5)
+
+ def testGhostBN4DimsAxis3(self):
+ shape = [6, 10, 10, 3]
+ num_virtual_batches = 3
+ beta = 2.
+ gamma = 3.
+ momentum = 0.8
+ epsilon = 1e-3
+ moving_means = np.zeros([1, 3, 1, 1, 3], dtype=np.float32)
+ moving_vars = np.ones([1, 3, 1, 1, 3], dtype=np.float32)
+
+ inp = array_ops.placeholder(dtypes.float32, shape)
+ is_training = array_ops.placeholder(dtypes.bool)
+ bn = normalization_layers.BatchNormalization(
+ axis=3,
+ momentum=momentum,
+ epsilon=epsilon,
+ beta_initializer=init_ops.constant_initializer(beta),
+ gamma_initializer=init_ops.constant_initializer(gamma),
+ num_virtual_batches=num_virtual_batches)
+ out = bn.apply(inp, training=is_training)
+ ghost_shape = ([shape[0] // num_virtual_batches, num_virtual_batches] +
+ shape[1:])
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ for _ in range(5):
+ x = np.random.random(shape)
+
+ sub_batched = np.reshape(x, ghost_shape)
+ means = np.mean(sub_batched, axis=(0, 2, 3), keepdims=True)
+ variances = np.var(sub_batched, axis=(0, 2, 3), keepdims=True)
+ moving_means = moving_means * momentum + means * (1. - momentum)
+ moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+ y_train = ((sub_batched - means) /
+ (variances + epsilon) ** 0.5 * gamma) + beta
+ y_test = ((sub_batched - moving_means) /
+ (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+ y_train = np.reshape(y_train, shape)
+ y_test = np.reshape(y_test, shape)
+
+ y_val_train, _, _ = sess.run([out] + bn.updates,
+ feed_dict={inp: x, is_training: True})
+ y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+ self.assertAllClose(y_train, y_val_train, atol=1e-2)
+ self.assertAllClose(y_test, y_val_test, atol=1e-2)
+
+ def testGhostBN4DimsAxis1(self):
+ shape = [6, 3, 10, 10]
+ num_virtual_batches = 3
+ beta = 2.
+ gamma = 3.
+ momentum = 0.8
+ epsilon = 1e-3
+ moving_means = np.zeros([1, 3, 3, 1, 1], dtype=np.float32)
+ moving_vars = np.ones([1, 3, 3, 1, 1], dtype=np.float32)
+
+ inp = array_ops.placeholder(dtypes.float32, shape)
+ is_training = array_ops.placeholder(dtypes.bool)
+ bn = normalization_layers.BatchNormalization(
+ axis=1,
+ momentum=momentum,
+ epsilon=epsilon,
+ beta_initializer=init_ops.constant_initializer(beta),
+ gamma_initializer=init_ops.constant_initializer(gamma),
+ num_virtual_batches=num_virtual_batches,
+ fused=False) # NCHW is unsupported by CPU fused batch norm
+ out = bn.apply(inp, training=is_training)
+ ghost_shape = ([shape[0] // num_virtual_batches, num_virtual_batches] +
+ shape[1:])
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ for _ in range(5):
+ x = np.random.random(shape)
+
+ sub_batched = np.reshape(x, ghost_shape)
+ means = np.mean(sub_batched, axis=(0, 3, 4), keepdims=True)
+ variances = np.var(sub_batched, axis=(0, 3, 4), keepdims=True)
+ moving_means = moving_means * momentum + means * (1. - momentum)
+ moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+ y_train = ((sub_batched - means) /
+ (variances + epsilon) ** 0.5 * gamma) + beta
+ y_test = ((sub_batched - moving_means) /
+ (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+ y_train = np.reshape(y_train, shape)
+ y_test = np.reshape(y_test, shape)
+
+ y_val_train, _, _ = sess.run([out] + bn.updates,
+ feed_dict={inp: x, is_training: True})
+ y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+ self.assertAllClose(y_train, y_val_train, atol=1e-2)
+ self.assertAllClose(y_test, y_val_test, atol=1e-2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
index 67d945a6ed..8417e0c347 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
@@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'None\'], "
+ argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'num_virtual_batches\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'1\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index f6d43d4c55..1176b17c9d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -94,7 +94,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\'], "
+ argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'num_virtual_batches\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'1\'], "
}
member_method {
name: "conv1d"