diff options
Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r-- | tensorflow/python/layers/normalization_test.py | 59 |
1 file changed, 59 insertions, 0 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index 91b7cb6f48..0f82f73ea4 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -513,6 +514,64 @@ class BNTest(test.TestCase): _ = bn.apply(inputs, training=training) self.assertEqual(len(bn.losses), 1) + def testRenorm(self): + shape = (4, 3) + xt = array_ops.placeholder(dtypes.float32, shape) + momentum = 0.99 + renorm_momentum = 0.8 + rmax = 1.1 + rmin = 0.9 + dmax = 0.1 + gamma = 2. + beta = 3. + epsilon = 0.001 + bn = normalization_layers.BatchNormalization( + axis=1, + gamma_initializer=init_ops.constant_initializer(gamma), + beta_initializer=init_ops.constant_initializer(beta), + epsilon=epsilon, + momentum=momentum, + renorm=True, + renorm_clipping={'rmax': rmax, 'rmin': rmin, 'dmax': dmax}, + renorm_momentum=renorm_momentum) + training = array_ops.placeholder(dtypes.bool) + yt = bn.apply(xt, training=training) + + moving_mean = 0. + moving_variance = 1. + renorm_mean = renorm_stddev = 0. + renorm_weight = 0. + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + for _ in range(5): + x = np.random.random(shape) + + mean = x.mean(0) + stddev = np.sqrt(x.var(0) + epsilon) + adj_mean = renorm_mean + (1. - renorm_weight) * mean + adj_stddev = renorm_stddev + (1. - renorm_weight) * stddev + r = (stddev / adj_stddev).clip(rmin, rmax) + d = ((mean - adj_mean) / adj_stddev).clip(-dmax, dmax) + y_train = ((x - mean) / stddev * r + d) * gamma + beta + renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum) + renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum) + renorm_weight += (1. - renorm_weight) * (1. - renorm_momentum) + moving_mean += (renorm_mean / renorm_weight - + moving_mean) * (1. - momentum) + moving_variance += ((renorm_stddev / renorm_weight) ** 2 - epsilon - + moving_variance) * (1. - momentum) + + y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 * + gamma) + beta + + yt_val_train, _, _ = sess.run([yt] + bn.updates, + feed_dict={xt: x, training: True}) + yt_val_test, _, _ = sess.run([yt] + bn.updates, + feed_dict={xt: x, training: False}) + + self.assertAllClose(y_train, yt_val_train, atol=1e-5) + self.assertAllClose(y_test, yt_val_test, atol=1e-5) + if __name__ == '__main__': test.main() |