diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-16 20:57:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 21:01:34 -0700 |
commit | a1606d5e0f667fddd7f3f5705bda3aee5b3c2554 (patch) | |
tree | 76e2f1d35826962394a9b58905ba5095c7f35ebf /tensorflow/contrib/layers | |
parent | 3e96a135d9650c91e307d1d56c81e5a37078cada (diff) |
Add one pass algorithm option to calculate the mean and variance in group_norm
PiperOrigin-RevId: 209096783
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/normalization.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/normalization_test.py | 215 |
3 files changed, 179 insertions, 63 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 7355a403ae..b4fe8cac74 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -185,7 +185,7 @@ py_test( py_test( name = "normalization_test", - size = "small", + size = "medium", srcs = ["python/layers/normalization_test.py"], srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index c807ab0f2e..85bd549393 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -176,7 +176,8 @@ def group_norm(inputs, variables_collections=None, outputs_collections=None, trainable=True, - scope=None): + scope=None, + mean_close_to_zero=False): """Functional interface for the group normalization layer. Reference: https://arxiv.org/abs/1803.08494. @@ -222,6 +223,19 @@ def group_norm(inputs, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). scope: Optional scope for `variable_scope`. + mean_close_to_zero: The mean of `input` before ReLU will be close to zero + when batch size >= 4k for Resnet-50 on TPU. If `True`, use + `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the + variance. This is the same behavior as `fused` equals `True` in batch + normalization. If `False`, use `nn.moments` to calculate the variance. + When `mean` is close to zero, like 1e-4, use `mean` to calculate the + variance may have poor result due to repeated roundoff error and + denormalization in `mean`. When `mean` is large, like 1e2, + sum(`input`^2) is so large that only the high-order digits of the elements + are being accumulated. Thus, use sum(`input` - `mean`)^2/n to calcualte + the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2) + when `mean` is large. + Returns: A `Tensor` representing the output of the operation. @@ -333,7 +347,14 @@ def group_norm(inputs, gamma = array_ops.reshape(gamma, params_shape_broadcast) # Calculate the moments. - mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + if mean_close_to_zero: + # One pass algorithm returns better result when mean is close to zero. + counts, means_ss, variance_ss, _ = nn.sufficient_statistics( + inputs, moments_axes, keep_dims=True) + mean, variance = nn.normalize_moments( + counts, means_ss, variance_ss, shift=None) + else: + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) # Compute normalization. # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index b6e96350db..5d8c899eec 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -57,8 +57,7 @@ class InstanceNormTest(test.TestCase): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = normalization.instance_norm(images) print('name: ', output.op.name) - self.assertStartsWith( - output.op.name, 'InstanceNorm/instancenorm') + self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm') self.assertListEqual([5, height, width, 3], output.shape.as_list()) def testCreateOpFloat64(self): @@ -66,8 +65,7 @@ class InstanceNormTest(test.TestCase): images = random_ops.random_uniform( (5, height, width, 3), dtype=dtypes.float64, seed=1) output = normalization.instance_norm(images) - self.assertStartsWith( - output.op.name, 'InstanceNorm/instancenorm') + self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm') self.assertListEqual([5, height, width, 3], output.shape.as_list()) def testCreateOpNoScaleCenter(self): @@ -75,8 +73,7 @@ class InstanceNormTest(test.TestCase): images = random_ops.random_uniform( (5, height, width, 3), dtype=dtypes.float64, seed=1) output = normalization.instance_norm(images, center=False, scale=False) - self.assertStartsWith( - output.op.name, 'InstanceNorm/instancenorm') + self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm') self.assertListEqual([5, height, width, 3], output.shape.as_list()) self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) @@ -173,24 +170,22 @@ class GroupNormTest(test.TestCase): inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10)) with self.assertRaisesRegexp(ValueError, 'Invalid groups 10 for 2 channels.'): - normalization.group_norm(inputs, groups=10, - reduction_axes=[-2, -1], channels_axis=-3) + normalization.group_norm( + inputs, groups=10, reduction_axes=[-2, -1], channels_axis=-3) def testBadCommensurateGroup(self): inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10)) - with self.assertRaisesRegexp(ValueError, - '4 channels is not commensurate with ' - '3 groups.'): - normalization.group_norm(inputs, groups=3, - reduction_axes=[-2, -1], channels_axis=-3) + with self.assertRaisesRegexp( + ValueError, '4 channels is not commensurate with ' + '3 groups.'): + normalization.group_norm( + inputs, groups=3, reduction_axes=[-2, -1], channels_axis=-3) def testAxisIsBad(self): inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5)) - with self.assertRaisesRegexp(ValueError, - 'Axis is out of bounds.'): + with self.assertRaisesRegexp(ValueError, 'Axis is out of bounds.'): normalization.group_norm(inputs, channels_axis=5) - with self.assertRaisesRegexp(ValueError, - 'Axis is out of bounds.'): + with self.assertRaisesRegexp(ValueError, 'Axis is out of bounds.'): normalization.group_norm(inputs, reduction_axes=[1, 5]) def testNotMutuallyExclusiveAxis(self): @@ -218,41 +213,45 @@ class GroupNormTest(test.TestCase): def testParamsShapeNotFullyDefinedChannelsAxis(self): inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None)) with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'): - normalization.group_norm(inputs, channels_axis=-1, - reduction_axes=[-3, -2]) + normalization.group_norm( + inputs, channels_axis=-1, reduction_axes=[-3, -2]) def testCreateOp(self): height, width, groups = 3, 3, 4 - images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) - output = normalization.group_norm(images, groups=groups, channels_axis=-1, - reduction_axes=[-3, -2]) + images = random_ops.random_uniform((5, height, width, 2 * groups), seed=1) + output = normalization.group_norm( + images, groups=groups, channels_axis=-1, reduction_axes=[-3, -2]) print('name: ', output.op.name) - self.assertListEqual([5, height, width, 2*groups], output.shape.as_list()) + self.assertListEqual([5, height, width, 2 * groups], output.shape.as_list()) def testCreateOpFloat64(self): height, width, groups = 3, 3, 5 images = random_ops.random_uniform( - (5, height, width, 4*groups), dtype=dtypes.float64, seed=1) + (5, height, width, 4 * groups), dtype=dtypes.float64, seed=1) output = normalization.group_norm(images, groups=groups) self.assertEqual(dtypes.float64, output.dtype) - self.assertListEqual([5, height, width, 4*groups], output.shape.as_list()) + self.assertListEqual([5, height, width, 4 * groups], output.shape.as_list()) def testCreateOpNoScaleCenter(self): height, width, groups = 3, 3, 7 images = random_ops.random_uniform( - (5, height, width, 3*groups), dtype=dtypes.float32, seed=1) - output = normalization.group_norm(images, groups=groups, center=False, - scale=False) - self.assertListEqual([5, height, width, 3*groups], output.shape.as_list()) + (5, height, width, 3 * groups), dtype=dtypes.float32, seed=1) + output = normalization.group_norm( + images, groups=groups, center=False, scale=False) + self.assertListEqual([5, height, width, 3 * groups], output.shape.as_list()) self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) def testCreateVariables_NHWC(self): height, width = 3, 3 images = random_ops.random_uniform((5, height, width, 8), seed=1) - normalization.group_norm(images, groups=4, - channels_axis=-1, reduction_axes=(-3, -2), - center=True, scale=True) + normalization.group_norm( + images, + groups=4, + channels_axis=-1, + reduction_axes=(-3, -2), + center=True, + scale=True) beta = contrib_variables.get_variables_by_name('beta')[0] gamma = contrib_variables.get_variables_by_name('gamma')[0] self.assertEqual('GroupNorm/beta', beta.op.name) @@ -260,10 +259,14 @@ class GroupNormTest(test.TestCase): def testCreateVariables_NCHW(self): height, width, groups = 3, 3, 4 - images = random_ops.random_uniform((5, 2*groups, height, width), seed=1) - normalization.group_norm(images, groups=4, - channels_axis=-3, reduction_axes=(-2, -1), - center=True, scale=True) + images = random_ops.random_uniform((5, 2 * groups, height, width), seed=1) + normalization.group_norm( + images, + groups=4, + channels_axis=-3, + reduction_axes=(-2, -1), + center=True, + scale=True) beta = contrib_variables.get_variables_by_name('beta')[0] gamma = contrib_variables.get_variables_by_name('gamma')[0] self.assertEqual('GroupNorm/beta', beta.op.name) @@ -273,8 +276,8 @@ class GroupNormTest(test.TestCase): height, width = 3, 3 images = random_ops.random_uniform((5, height, width, 4), seed=1) normalization.group_norm(images, groups=2, scale=True, scope='IN') - normalization.group_norm(images, groups=2, scale=True, scope='IN', - reuse=True) + normalization.group_norm( + images, groups=2, scale=True, scope='IN', reuse=True) beta = contrib_variables.get_variables_by_name('beta') gamma = contrib_variables.get_variables_by_name('gamma') self.assertEqual(1, len(beta)) @@ -285,16 +288,21 @@ class GroupNormTest(test.TestCase): image_shape = (10, height, width, 4) images = random_ops.random_uniform(image_shape, seed=1) output_train = normalization.group_norm(images, groups=2, scope='IN') - output_eval = normalization.group_norm(images, groups=2, scope='IN', - reuse=True) + output_eval = normalization.group_norm( + images, groups=2, scope='IN', reuse=True) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) # output_train and output_eval should be the same. train_np, eval_np = sess.run([output_train, output_eval]) self.assertAllClose(train_np, eval_np) - def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, - groups=2, tol=1e-2): + def doOutputTest(self, + input_shape, + channels_axis=None, + reduction_axes=None, + mean_close_to_zero=False, + groups=2, + tol=1e-2): # Select the axis for the channel and the dimensions along which statistics # are accumulated. if channels_axis < 0: @@ -306,15 +314,16 @@ class GroupNormTest(test.TestCase): if a < channels_axis: reduced_axes.append(a) else: - reduced_axes.append(a+1) + reduced_axes.append(a + 1) reduced_axes = tuple(reduced_axes) # Calculate the final shape for the output Tensor. axes_before_channels = input_shape[:channels_axis] - axes_after_channels = input_shape[channels_axis+1:] + axes_after_channels = input_shape[channels_axis + 1:] channels = input_shape[channels_axis] - outputs_shape = (axes_before_channels + [groups, channels // groups] + - axes_after_channels) + outputs_shape = ( + axes_before_channels + [groups, channels // groups] + + axes_after_channels) # Calculate the final shape for the output statistics. reduced_shape = [] @@ -322,17 +331,28 @@ class GroupNormTest(test.TestCase): if i not in reduced_axes: reduced_shape.append(a) - for mu in (0.0, 1e2): - for sigma in (1.0, 0.1): + if mean_close_to_zero: + mu_tuple = (1e-4, 1e-2, 1.0) + sigma_tuple = (1e-2, 0.1, 1.0) + else: + mu_tuple = (1.0, 1e2) + sigma_tuple = (1.0, 0.1) + + for mu in mu_tuple: + for sigma in sigma_tuple: # Determine shape of Tensor after normalization. expected_mean = np.zeros(reduced_shape) expected_var = np.ones(reduced_shape) - inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu output_op = normalization.group_norm( - inputs, groups=groups, center=False, scale=False, + inputs, + groups=groups, + center=False, + scale=False, channels_axis=channels_axis, - reduction_axes=reduction_axes) + reduction_axes=reduction_axes, + mean_close_to_zero=mean_close_to_zero) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) outputs = sess.run(output_op) @@ -353,6 +373,18 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + # Specify axes with positive values. + self.doOutputTest( + input_shape, + channels_axis=3, + reduction_axes=[1, 2], + mean_close_to_zero=True) + # Specify axes with negative values. + self.doOutputTest( + input_shape, + channels_axis=-1, + reduction_axes=[-3, -2], + mean_close_to_zero=True) def testOutputSmallInput3D_NHWC(self): input_shape = [10, 10, 30] @@ -360,6 +392,18 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + # Specify axes with positive values. + self.doOutputTest( + input_shape, + channels_axis=2, + reduction_axes=[0, 1], + mean_close_to_zero=True) + # Specify axes with negative values. + self.doOutputTest( + input_shape, + channels_axis=-1, + reduction_axes=[-3, -2], + mean_close_to_zero=True) def testOutputSmallInput4D_NCHW(self): input_shape = [10, 10, 10, 30] @@ -367,6 +411,18 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + # Specify axes with positive values. + self.doOutputTest( + input_shape, + channels_axis=1, + reduction_axes=[2, 3], + mean_close_to_zero=True) + # Specify axes with negative values. + self.doOutputTest( + input_shape, + channels_axis=-3, + reduction_axes=[-2, -1], + mean_close_to_zero=True) def testOutputSmallInput3D_NCHW(self): input_shape = [10, 10, 30] @@ -374,23 +430,62 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + # Specify axes with positive values. + self.doOutputTest( + input_shape, + channels_axis=0, + reduction_axes=[1, 2], + mean_close_to_zero=True) + # Specify axes with negative values. + self.doOutputTest( + input_shape, + channels_axis=-3, + reduction_axes=[-2, -1], + mean_close_to_zero=True) def testOutputBigInput4D_NHWC(self): - self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], - groups=1) + self.doOutputTest( + [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1) + self.doOutputTest( + [5, 100, 100, 1], + channels_axis=3, + reduction_axes=[1, 2], + groups=1, + mean_close_to_zero=True) def testOutputBigInput4D_NCHW(self): - self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], - groups=4) + self.doOutputTest( + [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4) + self.doOutputTest( + [1, 100, 100, 4], + channels_axis=1, + reduction_axes=[2, 3], + groups=4, + mean_close_to_zero=True) def testOutputSmallInput2D_NC(self): - self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + self.doOutputTest( + [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7) + self.doOutputTest( + [10, 7 * 100], + channels_axis=1, + reduction_axes=[], + groups=7, + mean_close_to_zero=True) def testOutputSmallInput5D_NCXXX(self): - self.doOutputTest([10, 10, 20, 40, 5], - channels_axis=1, - reduction_axes=[2, 3, 4], - groups=5) + self.doOutputTest( + [10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + self.doOutputTest( + [10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5, + mean_close_to_zero=True) + if __name__ == '__main__': test.main() |