aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-20 21:03:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 21:06:56 -0700
commit1fae2683be3ada987b8f487e96f62b4451df4393 (patch)
tree573af2a0071a4745dd8b812279579a8de476d26c /tensorflow/contrib/layers
parent7de22654844a41575a30cd1ce3a522abd0516fde (diff)
Add one pass algorithm option to calculate the mean and variance in group_norm. Fix normalization test in test fusion.
PiperOrigin-RevId: 209534762
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization.py25
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py100
3 files changed, 108 insertions, 19 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 7355a403ae..b4fe8cac74 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -185,7 +185,7 @@ py_test(
py_test(
name = "normalization_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/normalization_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"], # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index c807ab0f2e..11033a2e9c 100644
--- a/tensorflow/contrib/layers/python/layers/normalization.py
+++ b/tensorflow/contrib/layers/python/layers/normalization.py
@@ -176,7 +176,8 @@ def group_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
- scope=None):
+ scope=None,
+ mean_close_to_zero=False):
"""Functional interface for the group normalization layer.
Reference: https://arxiv.org/abs/1803.08494.
@@ -222,6 +223,19 @@ def group_norm(inputs,
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
scope: Optional scope for `variable_scope`.
+ mean_close_to_zero: The mean of `input` before ReLU will be close to zero
+ when batch size >= 4k for Resnet-50 on TPU. If `True`, use
+ `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the
+ variance. This is the same behavior as `fused` equals `True` in batch
+ normalization. If `False`, use `nn.moments` to calculate the variance.
+ When `mean` is close to zero, like 1e-4, using `mean` to calculate the
+ variance may give poor results due to repeated roundoff error and
+ denormalization in `mean`. When `mean` is large, like 1e2,
+ sum(`input`^2) is so large that only the high-order digits of the elements
+ are being accumulated. Thus, using sum((`input` - `mean`)^2)/n to
+ calculate the variance has better accuracy than
+ (sum(`input`^2)/n - `mean`^2) when `mean` is large.
+
Returns:
A `Tensor` representing the output of the operation.
@@ -333,7 +347,14 @@ def group_norm(inputs,
gamma = array_ops.reshape(gamma, params_shape_broadcast)
# Calculate the moments.
- mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+ if mean_close_to_zero:
+ # One-pass algorithm gives better results when the mean is close to zero.
+ counts, means_ss, variance_ss, _ = nn.sufficient_statistics(
+ inputs, moments_axes, keep_dims=True)
+ mean, variance = nn.normalize_moments(
+ counts, means_ss, variance_ss, shift=None)
+ else:
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
# Compute normalization.
# TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index b6e96350db..55272e5fd1 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -293,8 +293,13 @@ class GroupNormTest(test.TestCase):
train_np, eval_np = sess.run([output_train, output_eval])
self.assertAllClose(train_np, eval_np)
- def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
- groups=2, tol=1e-2):
+ def doOutputTest(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ mean_close_to_zero=False,
+ groups=2,
+ tol=1e-2):
# Select the axis for the channel and the dimensions along which statistics
# are accumulated.
if channels_axis < 0:
@@ -322,17 +327,28 @@ class GroupNormTest(test.TestCase):
if i not in reduced_axes:
reduced_shape.append(a)
- for mu in (0.0, 1e2):
- for sigma in (1.0, 0.1):
+ if mean_close_to_zero:
+ mu_tuple = (1e-4, 1e-2, 1.0)
+ sigma_tuple = (1e-2, 0.1, 1.0)
+ else:
+ mu_tuple = (1.0, 1e2)
+ sigma_tuple = (1.0, 0.1)
+
+ for mu in mu_tuple:
+ for sigma in sigma_tuple:
# Determine shape of Tensor after normalization.
expected_mean = np.zeros(reduced_shape)
expected_var = np.ones(reduced_shape)
- inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu
output_op = normalization.group_norm(
- inputs, groups=groups, center=False, scale=False,
+ inputs,
+ groups=groups,
+ center=False,
+ scale=False,
channels_axis=channels_axis,
- reduction_axes=reduction_axes)
+ reduction_axes=reduction_axes,
+ mean_close_to_zero=mean_close_to_zero)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
@@ -347,12 +363,32 @@ class GroupNormTest(test.TestCase):
self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+ def doOutputTestForMeanCloseToZero(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ groups=2,
+ tol=5e-2):
+ self.doOutputTest(
+ input_shape,
+ channels_axis=channels_axis,
+ reduction_axes=reduction_axes,
+ groups=groups,
+ tol=tol,
+ mean_close_to_zero=True)
+
def testOutputSmallInput4D_NHWC(self):
input_shape = [10, 10, 10, 30]
# Specify axes with positive values.
self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=3, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput3D_NHWC(self):
input_shape = [10, 10, 30]
@@ -360,6 +396,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=2, reduction_axes=[0, 1])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput4D_NCHW(self):
input_shape = [10, 10, 10, 30]
@@ -367,6 +409,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=1, reduction_axes=[2, 3])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputSmallInput3D_NCHW(self):
input_shape = [10, 10, 30]
@@ -374,23 +422,43 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=0, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputBigInput4D_NHWC(self):
- self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
- groups=1)
+ self.doOutputTest(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
+ self.doOutputTestForMeanCloseToZero(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
def testOutputBigInput4D_NCHW(self):
- self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
- groups=4)
+ self.doOutputTest(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
+ self.doOutputTestForMeanCloseToZero(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
def testOutputSmallInput2D_NC(self):
- self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTest(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
def testOutputSmallInput5D_NCXXX(self):
- self.doOutputTest([10, 10, 20, 40, 5],
- channels_axis=1,
- reduction_axes=[2, 3, 4],
- groups=5)
+ self.doOutputTest(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+
if __name__ == '__main__':
test.main()