diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers_test.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index c67702fe98..b40a8936c7 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1639,6 +1639,29 @@ class BatchNormTest(tf.test.TestCase): self.assertAllClose(moving_mean.eval(), expected_mean) self.assertAllClose(moving_variance.eval(), expected_var) + def testCustomInitializer(self): + height, width = 3, 3 + channels = 3 + with self.test_session() as sess: + images = np.ones((5, height, width, channels))*9.0 + beta = tf.constant_initializer(np.ones(channels)*5.0) + gamma = tf.constant_initializer(np.ones(channels)*2.0) + mean = tf.constant_initializer(np.ones(channels)*5.0) + variance = tf.constant_initializer(np.ones(channels)*4.0) + output = tf.contrib.layers.batch_norm(images, + is_training=False, + scale=True, + epsilon=0.0, + initializers={ + 'beta': beta, + 'gamma': gamma, + 'moving_mean': mean, + 'moving_variance': variance, + }) + sess.run(tf.initialize_all_variables()) + outs = sess.run(output) + self.assertAllClose(outs, images) + class LayerNormTest(tf.test.TestCase): |