author     Chris Ying <chrisying@google.com>                2017-09-28 11:05:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-09-28 11:09:20 -0700
commit     125f7afa4a483855dc75791445d2dea64587876a (patch)
tree       bc9f097825e600b1e84b194577711f95ccf584dc
parent     d3d60ff6acec178b1cf912938aa6180bbd1a676f (diff)
Implementing ghost batch norm as defined in https://arxiv.org/pdf/1705.08741.
Reuses most of tf.layers.batch_normalization's existing functionality via some reshaping and transposing tricks. Toggled by the new optional parameter `num_virtual_batches`.
Ghost batch norm is essential for large-batch training, where the true batch
size differs from the batch norm batch size.
PiperOrigin-RevId: 170368495
4 files changed, 279 insertions, 12 deletions
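
Before the diff itself, a minimal usage sketch of the new parameter (editorial, not part of the commit; assumes the TF 1.x API at this revision): a batch of 64 images is normalized as four independent virtual batches of 16 examples each.

    import tensorflow as tf

    # NHWC input; the leading batch dimension (64) must be divisible by
    # num_virtual_batches (4), as the docstring below requires.
    inputs = tf.placeholder(tf.float32, [64, 32, 32, 3])
    outputs = tf.layers.batch_normalization(
        inputs,
        axis=3,                 # normalize the channels axis
        training=True,
        num_virtual_batches=4)  # ghost batch norm over 4 sub-batches of 16
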
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index f9fe7b34bb..bcdb67ae90 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -49,7 +49,7 @@ class BatchNormalization(base.Layer):
     Sergey Ioffe, Christian Szegedy
 
   Arguments:
-    axis: Integer, the axis that should be normalized (typically the features
+    axis: An `int`, the axis that should be normalized (typically the features
       axis). For instance, after a `Conv2D` layer with
       `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
     momentum: Momentum for the moving average.
@@ -90,6 +90,11 @@ class BatchNormalization(base.Layer):
       If `None`, use the system recommended implementation.
     trainable: Boolean, if `True` also add variables to the graph collection
       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+    num_virtual_batches: An `int`, specifies the number of virtual batches to
+      operate over. If greater than 1, will perform "ghost batch
+      normalization", which creates virtual sub-batches to operate over for
+      batch norm. Default is 1 virtual batch, in which case no virtual
+      batching is performed. Must divide the actual batch size during graph execution.
     name: A string, the name of the layer.
   """
 
@@ -112,6 +117,7 @@ class BatchNormalization(base.Layer):
                renorm_momentum=0.99,
                fused=None,
                trainable=True,
+               num_virtual_batches=1,
                name=None,
                **kwargs):
     super(BatchNormalization, self).__init__(
@@ -135,6 +141,11 @@ class BatchNormalization(base.Layer):
     self.fused = fused
     self._bessels_correction_test_only = True
 
+    if num_virtual_batches < 1:
+      raise ValueError('num_virtual_batches must be a positive integer')
+    self.num_virtual_batches = num_virtual_batches
+
     if renorm:
       renorm_clipping = renorm_clipping or {}
       keys = ['rmax', 'rmin', 'dmax']
@@ -180,6 +191,10 @@ class BatchNormalization(base.Layer):
     self.input_spec = base.InputSpec(ndim=ndim,
                                      axes={self.axis: param_dim.value})
 
+    if self.num_virtual_batches > 1:
+      # The axis dim is combined with num_virtual_batches
+      param_dim = input_shape[axis] * self.num_virtual_batches
+
     if self.scale:
       self.gamma = self.add_variable(name='gamma',
                                      shape=(param_dim,),
@@ -389,8 +404,53 @@ class BatchNormalization(base.Layer):
       return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=False):
+    if self.num_virtual_batches > 1:
+      # Virtual batches (aka ghost batches) can be simulated by using some
+      # reshape/transpose tricks on top of base batch normalization.
+      original_shape = [-1] + inputs.shape.as_list()[1:]
+      expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]
+
+      # Will cause errors if num_virtual_batches does not divide the batch size
+      inputs = array_ops.reshape(inputs, expanded_shape)
+
+      ndims = len(expanded_shape)
+      if self.axis < 0:
+        axis = ndims + self.axis
+      else:
+        axis = self.axis + 1  # Account for the added dimension
+
+      # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
+      # TODO(b/66257056): when multi-axis batch normalization is implemented,
+      # this permutation trick and the combined_dim reshape are no longer
+      # necessary and can be reworked to simply use broadcasting.
+      permutation = ([0] + list(range(2, axis)) + [1, axis] +
+                     list(range(axis + 1, ndims)))
+      inverse_permutation = [x[1] for x in
+                             sorted(zip(permutation, range(ndims)))]
+      inputs = array_ops.transpose(inputs, perm=permutation)
+
+      # Combine the axis and num_virtual_batch dimension in order to take
+      # advantage of fused batch normalization
+      combined_dim = expanded_shape[1] * expanded_shape[axis]
+      perm_shape = [-1] + inputs.shape.as_list()[1:]
+      combined_shape = (perm_shape[:axis - 1] +
+                        [combined_dim] +
+                        perm_shape[axis + 1:])
+      inputs = array_ops.reshape(inputs, combined_shape)
+      # After the above reshape, the batch norm axis is the original self.axis
+
+      # Undoes the reshaping and transposing tricks done above
+      def undo_virtual_batching(outputs):
+        outputs = array_ops.reshape(outputs, perm_shape)
+        outputs = array_ops.transpose(outputs, perm=inverse_permutation)
+        outputs = array_ops.reshape(outputs, original_shape)
+        return outputs
+
     if self.fused:
-      return self._fused_batch_norm(inputs, training=training)
+      outputs = self._fused_batch_norm(inputs, training=training)
+      if self.num_virtual_batches > 1:
+        return undo_virtual_batching(outputs)
+      return outputs
 
     # First, compute the axes along which to reduce the mean / variance,
     # as well as the broadcast shape to be used for all parameters.
@@ -454,12 +514,17 @@ class BatchNormalization(base.Layer):
         return array_ops.reshape(v, broadcast_shape)
       return v
 
-    return nn.batch_normalization(inputs,
-                                  _broadcast(mean),
-                                  _broadcast(variance),
-                                  _broadcast(offset),
-                                  _broadcast(scale),
-                                  self.epsilon)
+    outputs = nn.batch_normalization(inputs,
+                                     _broadcast(mean),
+                                     _broadcast(variance),
+                                     _broadcast(offset),
+                                     _broadcast(scale),
+                                     self.epsilon)
+
+    if self.num_virtual_batches > 1:
+      return undo_virtual_batching(outputs)
+
+    return outputs
 
 
 def batch_normalization(inputs,
@@ -483,7 +548,8 @@ def batch_normalization(inputs,
                         renorm=False,
                         renorm_clipping=None,
                         renorm_momentum=0.99,
-                        fused=None):
+                        fused=None,
+                        num_virtual_batches=1):
   """Functional interface for the batch normalization layer.
 
   Reference: http://arxiv.org/abs/1502.03167
@@ -505,7 +571,7 @@ def batch_normalization(inputs,
 
   Arguments:
     inputs: Tensor input.
-    axis: Integer, the axis that should be normalized (typically the features
+    axis: An `int`, the axis that should be normalized (typically the features
       axis). For instance, after a `Convolution2D` layer with
       `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
     momentum: Momentum for the moving average.
@@ -555,6 +621,11 @@ def batch_normalization(inputs,
       to get the means and variances for inference.
     fused: if `True`, use a faster, fused implementation if possible.
       If `None`, use the system recommended implementation.
+    num_virtual_batches: An `int`, specifies the number of virtual batches to
+      operate over. If greater than 1, will perform "ghost batch
+      normalization", which creates virtual sub-batches to operate over for
+      batch norm. Default is 1 virtual batch, in which case no virtual
+      batching is performed. Must divide the actual batch size during graph execution.
 
   Returns:
     Output tensor.
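
The reshape/transpose trick in call() above can be sanity-checked outside TensorFlow. The following NumPy sketch (editorial commentary, not part of this commit) verifies, for the default axis=-1 on an NHWC input, that normalizing over the combined num_virtual_batches * channels axis is equivalent to giving each virtual sub-batch its own per-channel statistics; the transpose (0, 2, 3, 1, 4) is exactly the permutation computed above for ndims=5, axis=4.

    import numpy as np

    N, H, W, C, V = 6, 4, 4, 3, 2  # batch, spatial dims, channels, virtual batches
    x = np.random.randn(N, H, W, C).astype(np.float32)
    eps = 1e-3

    # Reference: split the batch into V virtual sub-batches along a new dim 1
    # and normalize each sub-batch with its own per-channel mean/variance.
    sub = x.reshape(N // V, V, H, W, C)
    mean = sub.mean(axis=(0, 2, 3), keepdims=True)
    var = sub.var(axis=(0, 2, 3), keepdims=True)
    reference = ((sub - mean) / np.sqrt(var + eps)).reshape(N, H, W, C)

    # Trick from call(): permute the virtual-batch dim next to the channel
    # axis, merge them into one axis of size V * C, and normalize over that
    # combined axis as ordinary (fused) batch norm would.
    t = sub.transpose(0, 2, 3, 1, 4).reshape(N // V, H, W, V * C)
    normed = (t - t.mean(axis=(0, 1, 2), keepdims=True)) / np.sqrt(
        t.var(axis=(0, 1, 2), keepdims=True) + eps)

    # Undo the merge/permute/reshape, mirroring undo_virtual_batching().
    out = (normed.reshape(N // V, H, W, V, C)
           .transpose(0, 3, 1, 2, 4)
           .reshape(N, H, W, C))

    assert np.allclose(reference, out, atol=1e-5)
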
@@ -578,6 +649,7 @@ def batch_normalization(inputs,
       renorm_momentum=renorm_momentum,
       fused=fused,
       trainable=trainable,
+      num_virtual_batches=num_virtual_batches,
       name=name,
       _reuse=reuse,
       _scope=name)
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 3dc6a33b44..ccb0662c4e 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -823,6 +823,201 @@ class BNTest(test.TestCase):
     self.assertAllClose(y_train, yt_val_train, atol=1e-5)
     self.assertAllClose(y_test, yt_val_test, atol=1e-5)
 
+  def testGhostBNVirtualBatch1(self):
+    shape = [6, 5, 4, 3]
+    inp = random_ops.random_uniform(shape, seed=1)
+    out1 = normalization_layers.batch_normalization(inp)
+    out2 = normalization_layers.batch_normalization(
+        inp, num_virtual_batches=1)
+
+    self.assertListEqual(
+        out1.shape.as_list(), out2.shape.as_list())
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+
+      x = np.random.random(shape)
+      y1, y2 = sess.run([out1, out2], feed_dict={inp: x})
+
+      self.assertAllClose(y1, y2, atol=1e-5)
+
+  def testGhostBNNegativeVirtualBatch(self):
+    shape = [6, 5, 4, 3]
+    inp = random_ops.random_uniform(shape, seed=1)
+
+    with self.assertRaises(ValueError):
+      normalization_layers.batch_normalization(
+          inp, num_virtual_batches=-1)
+
+  def testGhostBNInputOutputShapesMatch(self):
+    shape = [6, 4, 3]
+    inp = random_ops.random_uniform(shape, seed=1)
+    out = normalization_layers.batch_normalization(
+        inp, num_virtual_batches=2)
+    self.assertListEqual(out.shape.as_list(), shape)
+
+  def testGhostBNUnknownBatchSize(self):
+    np_shape = [10, 5, 4]
+    tf_shape = [None, 5, 4]
+    inp = array_ops.placeholder(dtypes.float32, tf_shape)
+    out = normalization_layers.batch_normalization(
+        inp, num_virtual_batches=5)
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+
+      x = np.random.random(np_shape)
+      y = sess.run(out, feed_dict={inp: x})
+
+      self.assertListEqual(list(y.shape), np_shape)
+
+  def testGhostBN2Dims(self):
+    shape = [6, 2]
+    num_virtual_batches = 2
+    beta = 2.
+    gamma = 3.
+    momentum = 0.8
+    epsilon = 1e-3
+    moving_means = np.zeros([2, 2], dtype=np.float32)
+    moving_vars = np.ones([2, 2], dtype=np.float32)
+
+    inp = array_ops.placeholder(dtypes.float32, shape)
+    is_training = array_ops.placeholder(dtypes.bool)
+    bn = normalization_layers.BatchNormalization(
+        momentum=momentum,
+        epsilon=epsilon,
+        beta_initializer=init_ops.constant_initializer(beta),
+        gamma_initializer=init_ops.constant_initializer(gamma),
+        num_virtual_batches=num_virtual_batches)
+    out = bn.apply(inp, training=is_training)
+    ghost_shape = ([shape[0] // num_virtual_batches,
+                    num_virtual_batches, shape[1]])
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      for _ in range(5):
+        x = np.random.random(shape)
+
+        sub_batched = np.reshape(x, ghost_shape)
+        means = np.mean(sub_batched, axis=0)
+        variances = np.var(sub_batched, axis=0)
+        moving_means = moving_means * momentum + means * (1. - momentum)
+        moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+        y_train = ((sub_batched - means) /
+                   (variances + epsilon) ** 0.5 * gamma) + beta
+        y_test = ((sub_batched - moving_means) /
+                  (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+        y_train = np.reshape(y_train, shape)
+        y_test = np.reshape(y_test, shape)
+
+        y_val_train, _, _ = sess.run([out] + bn.updates,
+                                     feed_dict={inp: x, is_training: True})
+        y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+        self.assertAllClose(y_train, y_val_train, atol=1e-5)
+        self.assertAllClose(y_test, y_val_test, atol=1e-5)
+
+  def testGhostBN4DimsAxis3(self):
+    shape = [6, 10, 10, 3]
+    num_virtual_batches = 3
+    beta = 2.
+    gamma = 3.
+    momentum = 0.8
+    epsilon = 1e-3
+    moving_means = np.zeros([1, 3, 1, 1, 3], dtype=np.float32)
+    moving_vars = np.ones([1, 3, 1, 1, 3], dtype=np.float32)
+
+    inp = array_ops.placeholder(dtypes.float32, shape)
+    is_training = array_ops.placeholder(dtypes.bool)
+    bn = normalization_layers.BatchNormalization(
+        axis=3,
+        momentum=momentum,
+        epsilon=epsilon,
+        beta_initializer=init_ops.constant_initializer(beta),
+        gamma_initializer=init_ops.constant_initializer(gamma),
+        num_virtual_batches=num_virtual_batches)
+    out = bn.apply(inp, training=is_training)
+    ghost_shape = ([shape[0] // num_virtual_batches, num_virtual_batches] +
+                   shape[1:])
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      for _ in range(5):
+        x = np.random.random(shape)
+
+        sub_batched = np.reshape(x, ghost_shape)
+        means = np.mean(sub_batched, axis=(0, 2, 3), keepdims=True)
+        variances = np.var(sub_batched, axis=(0, 2, 3), keepdims=True)
+        moving_means = moving_means * momentum + means * (1. - momentum)
+        moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+        y_train = ((sub_batched - means) /
+                   (variances + epsilon) ** 0.5 * gamma) + beta
+        y_test = ((sub_batched - moving_means) /
+                  (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+        y_train = np.reshape(y_train, shape)
+        y_test = np.reshape(y_test, shape)
+
+        y_val_train, _, _ = sess.run([out] + bn.updates,
+                                     feed_dict={inp: x, is_training: True})
+        y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+        self.assertAllClose(y_train, y_val_train, atol=1e-2)
+        self.assertAllClose(y_test, y_val_test, atol=1e-2)
+
+  def testGhostBN4DimsAxis1(self):
+    shape = [6, 3, 10, 10]
+    num_virtual_batches = 3
+    beta = 2.
+    gamma = 3.
+    momentum = 0.8
+    epsilon = 1e-3
+    moving_means = np.zeros([1, 3, 3, 1, 1], dtype=np.float32)
+    moving_vars = np.ones([1, 3, 3, 1, 1], dtype=np.float32)
+
+    inp = array_ops.placeholder(dtypes.float32, shape)
+    is_training = array_ops.placeholder(dtypes.bool)
+    bn = normalization_layers.BatchNormalization(
+        axis=1,
+        momentum=momentum,
+        epsilon=epsilon,
+        beta_initializer=init_ops.constant_initializer(beta),
+        gamma_initializer=init_ops.constant_initializer(gamma),
+        num_virtual_batches=num_virtual_batches,
+        fused=False)  # NCHW is unsupported by CPU fused batch norm
+    out = bn.apply(inp, training=is_training)
+    ghost_shape = ([shape[0] // num_virtual_batches, num_virtual_batches] +
+                   shape[1:])
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      for _ in range(5):
+        x = np.random.random(shape)
+
+        sub_batched = np.reshape(x, ghost_shape)
+        means = np.mean(sub_batched, axis=(0, 3, 4), keepdims=True)
+        variances = np.var(sub_batched, axis=(0, 3, 4), keepdims=True)
+        moving_means = moving_means * momentum + means * (1. - momentum)
+        moving_vars = moving_vars * momentum + variances * (1. - momentum)
+
+        y_train = ((sub_batched - means) /
+                   (variances + epsilon) ** 0.5 * gamma) + beta
+        y_test = ((sub_batched - moving_means) /
+                  (moving_vars + epsilon) ** 0.5 * gamma) + beta
+
+        y_train = np.reshape(y_train, shape)
+        y_test = np.reshape(y_test, shape)
+
+        y_val_train, _, _ = sess.run([out] + bn.updates,
+                                     feed_dict={inp: x, is_training: True})
+        y_val_test = sess.run(out, feed_dict={inp: x, is_training: False})
+
+        self.assertAllClose(y_train, y_val_train, atol=1e-2)
+        self.assertAllClose(y_test, y_val_test, atol=1e-2)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
index 67d945a6ed..8417e0c347 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
@@ -65,7 +65,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'None\'], "
+    argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'num_virtual_batches\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'1\', \'None\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index f6d43d4c55..1176b17c9d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -94,7 +94,7 @@ tf_module {
   }
   member_method {
     name: "batch_normalization"
-    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\'], "
+    argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'num_virtual_batches\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'<tensorflow.python.ops.init_ops.Ones object instance>\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'1\'], "
   }
   member_method {
     name: "conv1d"
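
A practical consequence of the param_dim change in build() (editorial note, not part of the commit; variable names follow the layer's defaults): with num_virtual_batches > 1, gamma, beta, and the moving statistics are allocated with features * num_virtual_batches elements, since every virtual batch keeps its own statistics. This matches the moving_means / moving_vars shapes in the tests above, e.g. [2, 2] for 2 features and 2 virtual batches.

    import tensorflow as tf

    inp = tf.placeholder(tf.float32, [8, 16])  # batch of 8, 16 features
    tf.layers.batch_normalization(inp, num_virtual_batches=4)

    for v in tf.global_variables():
        print(v.name, v.shape)
    # Expected per the build() change: gamma, beta, moving_mean and
    # moving_variance each have shape (64,) == 16 features * 4 virtual batches.
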