Diffstat (limited to 'tensorflow/python/layers/normalization_test.py')
-rw-r--r--  tensorflow/python/layers/normalization_test.py  59
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 91b7cb6f48..0f82f73ea4 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.layers import normalization as normalization_layers
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
@@ -513,6 +514,64 @@ class BNTest(test.TestCase):
     _ = bn.apply(inputs, training=training)
     self.assertEqual(len(bn.losses), 1)
+  def testRenorm(self):
+    shape = (4, 3)
+    xt = array_ops.placeholder(dtypes.float32, shape)
+    momentum = 0.99
+    renorm_momentum = 0.8
+    rmax = 1.1
+    rmin = 0.9
+    dmax = 0.1
+    gamma = 2.
+    beta = 3.
+    epsilon = 0.001
+    bn = normalization_layers.BatchNormalization(
+        axis=1,
+        gamma_initializer=init_ops.constant_initializer(gamma),
+        beta_initializer=init_ops.constant_initializer(beta),
+        epsilon=epsilon,
+        momentum=momentum,
+        renorm=True,
+        renorm_clipping={'rmax': rmax, 'rmin': rmin, 'dmax': dmax},
+        renorm_momentum=renorm_momentum)
+    training = array_ops.placeholder(dtypes.bool)
+    yt = bn.apply(xt, training=training)
+
+    moving_mean = 0.
+    moving_variance = 1.
+    renorm_mean = renorm_stddev = 0.
+    renorm_weight = 0.
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      for _ in range(5):
+        x = np.random.random(shape)
+
+        mean = x.mean(0)
+        stddev = np.sqrt(x.var(0) + epsilon)
+        adj_mean = renorm_mean + (1. - renorm_weight) * mean
+        adj_stddev = renorm_stddev + (1. - renorm_weight) * stddev
+        r = (stddev / adj_stddev).clip(rmin, rmax)
+        d = ((mean - adj_mean) / adj_stddev).clip(-dmax, dmax)
+        y_train = ((x - mean) / stddev * r + d) * gamma + beta
+        renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum)
+        renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum)
+        renorm_weight += (1. - renorm_weight) * (1. - renorm_momentum)
+        moving_mean += (renorm_mean / renorm_weight -
+                        moving_mean) * (1. - momentum)
+        moving_variance += ((renorm_stddev / renorm_weight) ** 2 - epsilon -
+                            moving_variance) * (1. - momentum)
+
+        y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 *
+                  gamma) + beta
+
+        yt_val_train, _, _ = sess.run([yt] + bn.updates,
+                                      feed_dict={xt: x, training: True})
+        yt_val_test, _, _ = sess.run([yt] + bn.updates,
+                                     feed_dict={xt: x, training: False})
+
+        self.assertAllClose(y_train, yt_val_train, atol=1e-5)
+        self.assertAllClose(y_test, yt_val_test, atol=1e-5)
+
 if __name__ == '__main__':
   test.main()
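
Note on the reference math in the new test: the inline NumPy expressions are the expected batch-renormalization correction that the layer must reproduce, both for the training path (r, d applied to the batch statistics) and the inference path (moving statistics only). The sketch below is not part of the commit; it restates the training-time correction as a standalone NumPy function under the same constants as testRenorm, and the function name renorm_correction is illustrative only.

# Standalone NumPy sketch (not part of this commit) of the renorm correction
# that testRenorm computes inline. Names and default values are illustrative.
import numpy as np

def renorm_correction(x, renorm_mean, renorm_stddev, renorm_weight,
                      rmin=0.9, rmax=1.1, dmax=0.1, epsilon=0.001):
  mean = x.mean(0)
  stddev = np.sqrt(x.var(0) + epsilon)
  # Bias-correct the running estimates by the accumulated weight.
  adj_mean = renorm_mean + (1. - renorm_weight) * mean
  adj_stddev = renorm_stddev + (1. - renorm_weight) * stddev
  # r rescales and d shifts the batch-normalized activations toward the
  # moving statistics; both are clipped, and r = 1, d = 0 recovers plain
  # batch normalization.
  r = (stddev / adj_stddev).clip(rmin, rmax)
  d = ((mean - adj_mean) / adj_stddev).clip(-dmax, dmax)
  return (x - mean) / stddev * r + d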