aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/init_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/init_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index a9b55854f1..795aa67248 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -362,6 +362,33 @@ class UniformUnitScalingInitializationTest(test.TestCase):
dtype=dtypes.string)
+class VarianceScalingInitializationTest(test.TestCase):
+
+ def testNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(distribution='normal')
+
+ with self.test_session(use_gpu=True):
+ x = init(shape).eval()
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+ def testUniformDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(distribution='uniform')
+
+ with self.test_session(use_gpu=True):
+ x = init(shape).eval()
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+
# TODO(vrv): move to sequence_ops_test?
class RangeTest(test.TestCase):