diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-20 21:03:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 21:06:56 -0700 |
commit | 1fae2683be3ada987b8f487e96f62b4451df4393 (patch) | |
tree | 573af2a0071a4745dd8b812279579a8de476d26c /tensorflow/contrib/layers | |
parent | 7de22654844a41575a30cd1ce3a522abd0516fde (diff) |
Add a one-pass algorithm option to calculate the mean and variance in group_norm. Fix the normalization test in test fusion.
PiperOrigin-RevId: 209534762
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/normalization.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/normalization_test.py | 100 |
3 files changed, 108 insertions, 19 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 7355a403ae..b4fe8cac74 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -185,7 +185,7 @@ py_test( py_test( name = "normalization_test", - size = "small", + size = "medium", srcs = ["python/layers/normalization_test.py"], srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index c807ab0f2e..11033a2e9c 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -176,7 +176,8 @@ def group_norm(inputs, variables_collections=None, outputs_collections=None, trainable=True, - scope=None): + scope=None, + mean_close_to_zero=False): """Functional interface for the group normalization layer. Reference: https://arxiv.org/abs/1803.08494. @@ -222,6 +223,19 @@ def group_norm(inputs, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). scope: Optional scope for `variable_scope`. + mean_close_to_zero: The mean of `input` before ReLU will be close to zero + when batch size >= 4k for Resnet-50 on TPU. If `True`, use + `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the + variance. This is the same behavior as `fused` equals `True` in batch + normalization. If `False`, use `nn.moments` to calculate the variance. + When `mean` is close to zero, like 1e-4, use `mean` to calculate the + variance may have poor result due to repeated roundoff error and + denormalization in `mean`. When `mean` is large, like 1e2, + sum(`input`^2) is so large that only the high-order digits of the elements + are being accumulated. 
Thus, use sum(`input` - `mean`)^2/n to calculate + the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2) + when `mean` is large. + Returns: A `Tensor` representing the output of the operation. @@ -333,7 +347,14 @@ def group_norm(inputs, gamma = array_ops.reshape(gamma, params_shape_broadcast) # Calculate the moments. - mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + if mean_close_to_zero: + # One pass algorithm returns better result when mean is close to zero. + counts, means_ss, variance_ss, _ = nn.sufficient_statistics( + inputs, moments_axes, keep_dims=True) + mean, variance = nn.normalize_moments( + counts, means_ss, variance_ss, shift=None) + else: + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) # Compute normalization. # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index b6e96350db..55272e5fd1 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -293,8 +293,13 @@ class GroupNormTest(test.TestCase): train_np, eval_np = sess.run([output_train, output_eval]) self.assertAllClose(train_np, eval_np) - def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, - groups=2, tol=1e-2): + def doOutputTest(self, + input_shape, + channels_axis=None, + reduction_axes=None, + mean_close_to_zero=False, + groups=2, + tol=1e-2): # Select the axis for the channel and the dimensions along which statistics # are accumulated. 
if channels_axis < 0: @@ -322,17 +327,28 @@ class GroupNormTest(test.TestCase): if i not in reduced_axes: reduced_shape.append(a) - for mu in (0.0, 1e2): - for sigma in (1.0, 0.1): + if mean_close_to_zero: + mu_tuple = (1e-4, 1e-2, 1.0) + sigma_tuple = (1e-2, 0.1, 1.0) + else: + mu_tuple = (1.0, 1e2) + sigma_tuple = (1.0, 0.1) + + for mu in mu_tuple: + for sigma in sigma_tuple: # Determine shape of Tensor after normalization. expected_mean = np.zeros(reduced_shape) expected_var = np.ones(reduced_shape) - inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu output_op = normalization.group_norm( - inputs, groups=groups, center=False, scale=False, + inputs, + groups=groups, + center=False, + scale=False, channels_axis=channels_axis, - reduction_axes=reduction_axes) + reduction_axes=reduction_axes, + mean_close_to_zero=mean_close_to_zero) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) outputs = sess.run(output_op) @@ -347,12 +363,32 @@ class GroupNormTest(test.TestCase): self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol) self.assertAllClose(expected_var, var, rtol=tol, atol=tol) + def doOutputTestForMeanCloseToZero(self, + input_shape, + channels_axis=None, + reduction_axes=None, + groups=2, + tol=5e-2): + self.doOutputTest( + input_shape, + channels_axis=channels_axis, + reduction_axes=reduction_axes, + groups=groups, + tol=tol, + mean_close_to_zero=True) + def testOutputSmallInput4D_NHWC(self): input_shape = [10, 10, 10, 30] # Specify axes with positive values. self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + # Specify axes with positive values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=3, reduction_axes=[1, 2]) + # Specify axes with negative values. 
+ self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=-1, reduction_axes=[-3, -2]) def testOutputSmallInput3D_NHWC(self): input_shape = [10, 10, 30] @@ -360,6 +396,12 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + # Specify axes with positive values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=2, reduction_axes=[0, 1]) + # Specify axes with negative values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=-1, reduction_axes=[-3, -2]) def testOutputSmallInput4D_NCHW(self): input_shape = [10, 10, 10, 30] @@ -367,6 +409,12 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + # Specify axes with positive values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=1, reduction_axes=[2, 3]) + # Specify axes with negative values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=-3, reduction_axes=[-2, -1]) def testOutputSmallInput3D_NCHW(self): input_shape = [10, 10, 30] @@ -374,23 +422,43 @@ class GroupNormTest(test.TestCase): self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) # Specify axes with negative values. self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + # Specify axes with positive values. + self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=0, reduction_axes=[1, 2]) + # Specify axes with negative values. 
+ self.doOutputTestForMeanCloseToZero( + input_shape, channels_axis=-3, reduction_axes=[-2, -1]) def testOutputBigInput4D_NHWC(self): - self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], - groups=1) + self.doOutputTest( + [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1) + self.doOutputTestForMeanCloseToZero( + [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1) def testOutputBigInput4D_NCHW(self): - self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], - groups=4) + self.doOutputTest( + [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4) + self.doOutputTestForMeanCloseToZero( + [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4) def testOutputSmallInput2D_NC(self): - self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + self.doOutputTest( + [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7) + self.doOutputTestForMeanCloseToZero( + [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7) def testOutputSmallInput5D_NCXXX(self): - self.doOutputTest([10, 10, 20, 40, 5], - channels_axis=1, - reduction_axes=[2, 3, 4], - groups=5) + self.doOutputTest( + [10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + self.doOutputTestForMeanCloseToZero( + [10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + if __name__ == '__main__': test.main() |