aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/layers_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers_test.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index c67702fe98..b40a8936c7 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1639,6 +1639,29 @@ class BatchNormTest(tf.test.TestCase):
self.assertAllClose(moving_mean.eval(), expected_mean)
self.assertAllClose(moving_variance.eval(), expected_var)
+ def testCustomInitializer(self):
+ height, width = 3, 3
+ channels = 3
+ with self.test_session() as sess:
+ images = np.ones((5, height, width, channels))*9.0
+ beta = tf.constant_initializer(np.ones(channels)*5.0)
+ gamma = tf.constant_initializer(np.ones(channels)*2.0)
+ mean = tf.constant_initializer(np.ones(channels)*5.0)
+ variance = tf.constant_initializer(np.ones(channels)*4.0)
+ output = tf.contrib.layers.batch_norm(images,
+ is_training=False,
+ scale=True,
+ epsilon=0.0,
+ initializers={
+ 'beta': beta,
+ 'gamma': gamma,
+ 'moving_mean': mean,
+ 'moving_variance': variance,
+ })
+ sess.run(tf.initialize_all_variables())
+ outs = sess.run(output)
+ self.assertAllClose(outs, images)
+
class LayerNormTest(tf.test.TestCase):