aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-16 20:57:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 21:01:34 -0700
commita1606d5e0f667fddd7f3f5705bda3aee5b3c2554 (patch)
tree76e2f1d35826962394a9b58905ba5095c7f35ebf /tensorflow/contrib/layers
parent3e96a135d9650c91e307d1d56c81e5a37078cada (diff)
Add one pass algorithm option to calculate the mean and variance in group_norm
PiperOrigin-RevId: 209096783
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization.py25
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py215
3 files changed, 179 insertions, 63 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 7355a403ae..b4fe8cac74 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -185,7 +185,7 @@ py_test(
py_test(
name = "normalization_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/normalization_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"], # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index c807ab0f2e..85bd549393 100644
--- a/tensorflow/contrib/layers/python/layers/normalization.py
+++ b/tensorflow/contrib/layers/python/layers/normalization.py
@@ -176,7 +176,8 @@ def group_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
- scope=None):
+ scope=None,
+ mean_close_to_zero=False):
"""Functional interface for the group normalization layer.
Reference: https://arxiv.org/abs/1803.08494.
@@ -222,6 +223,19 @@ def group_norm(inputs,
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
scope: Optional scope for `variable_scope`.
+ mean_close_to_zero: The mean of `input` before ReLU will be close to zero
+ when batch size >= 4k for Resnet-50 on TPU. If `True`, use
+ `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the
+ variance. This is the same behavior as `fused` equals `True` in batch
+ normalization. If `False`, use `nn.moments` to calculate the variance.
+ When `mean` is close to zero, like 1e-4, using `mean` to calculate the
+ variance may give poor results due to repeated roundoff error and
+ denormalization in `mean`. When `mean` is large, like 1e2,
+ sum(`input`^2) is so large that only the high-order digits of the elements
+ are being accumulated. Thus, using sum((`input` - `mean`)^2)/n to
+ calculate the variance is more accurate than (sum(`input`^2)/n - `mean`^2)
+ when `mean` is large.
+
Returns:
A `Tensor` representing the output of the operation.
@@ -333,7 +347,14 @@ def group_norm(inputs,
gamma = array_ops.reshape(gamma, params_shape_broadcast)
# Calculate the moments.
- mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+ if mean_close_to_zero:
+ # The one-pass algorithm gives better results when mean is close to zero.
+ counts, means_ss, variance_ss, _ = nn.sufficient_statistics(
+ inputs, moments_axes, keep_dims=True)
+ mean, variance = nn.normalize_moments(
+ counts, means_ss, variance_ss, shift=None)
+ else:
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
# Compute normalization.
# TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index b6e96350db..5d8c899eec 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -57,8 +57,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = normalization.instance_norm(images)
print('name: ', output.op.name)
- self.assertStartsWith(
- output.op.name, 'InstanceNorm/instancenorm')
+ self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm')
self.assertListEqual([5, height, width, 3], output.shape.as_list())
def testCreateOpFloat64(self):
@@ -66,8 +65,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform(
(5, height, width, 3), dtype=dtypes.float64, seed=1)
output = normalization.instance_norm(images)
- self.assertStartsWith(
- output.op.name, 'InstanceNorm/instancenorm')
+ self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm')
self.assertListEqual([5, height, width, 3], output.shape.as_list())
def testCreateOpNoScaleCenter(self):
@@ -75,8 +73,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform(
(5, height, width, 3), dtype=dtypes.float64, seed=1)
output = normalization.instance_norm(images, center=False, scale=False)
- self.assertStartsWith(
- output.op.name, 'InstanceNorm/instancenorm')
+ self.assertStartsWith(output.op.name, 'InstanceNorm/instancenorm')
self.assertListEqual([5, height, width, 3], output.shape.as_list())
self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))
@@ -173,24 +170,22 @@ class GroupNormTest(test.TestCase):
inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10))
with self.assertRaisesRegexp(ValueError,
'Invalid groups 10 for 2 channels.'):
- normalization.group_norm(inputs, groups=10,
- reduction_axes=[-2, -1], channels_axis=-3)
+ normalization.group_norm(
+ inputs, groups=10, reduction_axes=[-2, -1], channels_axis=-3)
def testBadCommensurateGroup(self):
inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10))
- with self.assertRaisesRegexp(ValueError,
- '4 channels is not commensurate with '
- '3 groups.'):
- normalization.group_norm(inputs, groups=3,
- reduction_axes=[-2, -1], channels_axis=-3)
+ with self.assertRaisesRegexp(
+ ValueError, '4 channels is not commensurate with '
+ '3 groups.'):
+ normalization.group_norm(
+ inputs, groups=3, reduction_axes=[-2, -1], channels_axis=-3)
def testAxisIsBad(self):
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5))
- with self.assertRaisesRegexp(ValueError,
- 'Axis is out of bounds.'):
+ with self.assertRaisesRegexp(ValueError, 'Axis is out of bounds.'):
normalization.group_norm(inputs, channels_axis=5)
- with self.assertRaisesRegexp(ValueError,
- 'Axis is out of bounds.'):
+ with self.assertRaisesRegexp(ValueError, 'Axis is out of bounds.'):
normalization.group_norm(inputs, reduction_axes=[1, 5])
def testNotMutuallyExclusiveAxis(self):
@@ -218,41 +213,45 @@ class GroupNormTest(test.TestCase):
def testParamsShapeNotFullyDefinedChannelsAxis(self):
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None))
with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'):
- normalization.group_norm(inputs, channels_axis=-1,
- reduction_axes=[-3, -2])
+ normalization.group_norm(
+ inputs, channels_axis=-1, reduction_axes=[-3, -2])
def testCreateOp(self):
height, width, groups = 3, 3, 4
- images = random_ops.random_uniform((5, height, width, 2*groups), seed=1)
- output = normalization.group_norm(images, groups=groups, channels_axis=-1,
- reduction_axes=[-3, -2])
+ images = random_ops.random_uniform((5, height, width, 2 * groups), seed=1)
+ output = normalization.group_norm(
+ images, groups=groups, channels_axis=-1, reduction_axes=[-3, -2])
print('name: ', output.op.name)
- self.assertListEqual([5, height, width, 2*groups], output.shape.as_list())
+ self.assertListEqual([5, height, width, 2 * groups], output.shape.as_list())
def testCreateOpFloat64(self):
height, width, groups = 3, 3, 5
images = random_ops.random_uniform(
- (5, height, width, 4*groups), dtype=dtypes.float64, seed=1)
+ (5, height, width, 4 * groups), dtype=dtypes.float64, seed=1)
output = normalization.group_norm(images, groups=groups)
self.assertEqual(dtypes.float64, output.dtype)
- self.assertListEqual([5, height, width, 4*groups], output.shape.as_list())
+ self.assertListEqual([5, height, width, 4 * groups], output.shape.as_list())
def testCreateOpNoScaleCenter(self):
height, width, groups = 3, 3, 7
images = random_ops.random_uniform(
- (5, height, width, 3*groups), dtype=dtypes.float32, seed=1)
- output = normalization.group_norm(images, groups=groups, center=False,
- scale=False)
- self.assertListEqual([5, height, width, 3*groups], output.shape.as_list())
+ (5, height, width, 3 * groups), dtype=dtypes.float32, seed=1)
+ output = normalization.group_norm(
+ images, groups=groups, center=False, scale=False)
+ self.assertListEqual([5, height, width, 3 * groups], output.shape.as_list())
self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))
def testCreateVariables_NHWC(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 8), seed=1)
- normalization.group_norm(images, groups=4,
- channels_axis=-1, reduction_axes=(-3, -2),
- center=True, scale=True)
+ normalization.group_norm(
+ images,
+ groups=4,
+ channels_axis=-1,
+ reduction_axes=(-3, -2),
+ center=True,
+ scale=True)
beta = contrib_variables.get_variables_by_name('beta')[0]
gamma = contrib_variables.get_variables_by_name('gamma')[0]
self.assertEqual('GroupNorm/beta', beta.op.name)
@@ -260,10 +259,14 @@ class GroupNormTest(test.TestCase):
def testCreateVariables_NCHW(self):
height, width, groups = 3, 3, 4
- images = random_ops.random_uniform((5, 2*groups, height, width), seed=1)
- normalization.group_norm(images, groups=4,
- channels_axis=-3, reduction_axes=(-2, -1),
- center=True, scale=True)
+ images = random_ops.random_uniform((5, 2 * groups, height, width), seed=1)
+ normalization.group_norm(
+ images,
+ groups=4,
+ channels_axis=-3,
+ reduction_axes=(-2, -1),
+ center=True,
+ scale=True)
beta = contrib_variables.get_variables_by_name('beta')[0]
gamma = contrib_variables.get_variables_by_name('gamma')[0]
self.assertEqual('GroupNorm/beta', beta.op.name)
@@ -273,8 +276,8 @@ class GroupNormTest(test.TestCase):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 4), seed=1)
normalization.group_norm(images, groups=2, scale=True, scope='IN')
- normalization.group_norm(images, groups=2, scale=True, scope='IN',
- reuse=True)
+ normalization.group_norm(
+ images, groups=2, scale=True, scope='IN', reuse=True)
beta = contrib_variables.get_variables_by_name('beta')
gamma = contrib_variables.get_variables_by_name('gamma')
self.assertEqual(1, len(beta))
@@ -285,16 +288,21 @@ class GroupNormTest(test.TestCase):
image_shape = (10, height, width, 4)
images = random_ops.random_uniform(image_shape, seed=1)
output_train = normalization.group_norm(images, groups=2, scope='IN')
- output_eval = normalization.group_norm(images, groups=2, scope='IN',
- reuse=True)
+ output_eval = normalization.group_norm(
+ images, groups=2, scope='IN', reuse=True)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
self.assertAllClose(train_np, eval_np)
- def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
- groups=2, tol=1e-2):
+ def doOutputTest(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ mean_close_to_zero=False,
+ groups=2,
+ tol=1e-2):
# Select the axis for the channel and the dimensions along which statistics
# are accumulated.
if channels_axis < 0:
@@ -306,15 +314,16 @@ class GroupNormTest(test.TestCase):
if a < channels_axis:
reduced_axes.append(a)
else:
- reduced_axes.append(a+1)
+ reduced_axes.append(a + 1)
reduced_axes = tuple(reduced_axes)
# Calculate the final shape for the output Tensor.
axes_before_channels = input_shape[:channels_axis]
- axes_after_channels = input_shape[channels_axis+1:]
+ axes_after_channels = input_shape[channels_axis + 1:]
channels = input_shape[channels_axis]
- outputs_shape = (axes_before_channels + [groups, channels // groups] +
- axes_after_channels)
+ outputs_shape = (
+ axes_before_channels + [groups, channels // groups] +
+ axes_after_channels)
# Calculate the final shape for the output statistics.
reduced_shape = []
@@ -322,17 +331,28 @@ class GroupNormTest(test.TestCase):
if i not in reduced_axes:
reduced_shape.append(a)
- for mu in (0.0, 1e2):
- for sigma in (1.0, 0.1):
+ if mean_close_to_zero:
+ mu_tuple = (1e-4, 1e-2, 1.0)
+ sigma_tuple = (1e-2, 0.1, 1.0)
+ else:
+ mu_tuple = (1.0, 1e2)
+ sigma_tuple = (1.0, 0.1)
+
+ for mu in mu_tuple:
+ for sigma in sigma_tuple:
# Determine shape of Tensor after normalization.
expected_mean = np.zeros(reduced_shape)
expected_var = np.ones(reduced_shape)
- inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu
output_op = normalization.group_norm(
- inputs, groups=groups, center=False, scale=False,
+ inputs,
+ groups=groups,
+ center=False,
+ scale=False,
channels_axis=channels_axis,
- reduction_axes=reduction_axes)
+ reduction_axes=reduction_axes,
+ mean_close_to_zero=mean_close_to_zero)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
@@ -353,6 +373,18 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=3,
+ reduction_axes=[1, 2],
+ mean_close_to_zero=True)
+ # Specify axes with negative values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=-1,
+ reduction_axes=[-3, -2],
+ mean_close_to_zero=True)
def testOutputSmallInput3D_NHWC(self):
input_shape = [10, 10, 30]
@@ -360,6 +392,18 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=2,
+ reduction_axes=[0, 1],
+ mean_close_to_zero=True)
+ # Specify axes with negative values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=-1,
+ reduction_axes=[-3, -2],
+ mean_close_to_zero=True)
def testOutputSmallInput4D_NCHW(self):
input_shape = [10, 10, 10, 30]
@@ -367,6 +411,18 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=1,
+ reduction_axes=[2, 3],
+ mean_close_to_zero=True)
+ # Specify axes with negative values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=-3,
+ reduction_axes=[-2, -1],
+ mean_close_to_zero=True)
def testOutputSmallInput3D_NCHW(self):
input_shape = [10, 10, 30]
@@ -374,23 +430,62 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=0,
+ reduction_axes=[1, 2],
+ mean_close_to_zero=True)
+ # Specify axes with negative values.
+ self.doOutputTest(
+ input_shape,
+ channels_axis=-3,
+ reduction_axes=[-2, -1],
+ mean_close_to_zero=True)
def testOutputBigInput4D_NHWC(self):
- self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
- groups=1)
+ self.doOutputTest(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
+ self.doOutputTest(
+ [5, 100, 100, 1],
+ channels_axis=3,
+ reduction_axes=[1, 2],
+ groups=1,
+ mean_close_to_zero=True)
def testOutputBigInput4D_NCHW(self):
- self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
- groups=4)
+ self.doOutputTest(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
+ self.doOutputTest(
+ [1, 100, 100, 4],
+ channels_axis=1,
+ reduction_axes=[2, 3],
+ groups=4,
+ mean_close_to_zero=True)
def testOutputSmallInput2D_NC(self):
- self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTest(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTest(
+ [10, 7 * 100],
+ channels_axis=1,
+ reduction_axes=[],
+ groups=7,
+ mean_close_to_zero=True)
def testOutputSmallInput5D_NCXXX(self):
- self.doOutputTest([10, 10, 20, 40, 5],
- channels_axis=1,
- reduction_axes=[2, 3, 4],
- groups=5)
+ self.doOutputTest(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+ self.doOutputTest(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5,
+ mean_close_to_zero=True)
+
if __name__ == '__main__':
test.main()