aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/init_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/init_ops.py')
-rw-r--r--tensorflow/python/ops/init_ops.py99
1 files changed, 96 insertions, 3 deletions
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 5bfc5ce2a7..c315722b6b 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -1136,7 +1136,8 @@ convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name
-@tf_export("glorot_uniform_initializer")
+@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform",
+ "initializers.glorot_uniform")
def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
@@ -1160,7 +1161,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
-@tf_export("glorot_normal_initializer")
+@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal",
+ "initializers.glorot_normal")
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot normal initializer, also called Xavier normal initializer.
@@ -1181,7 +1183,98 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
An initializer.
"""
return variance_scaling_initializer(
- scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype)
+ scale=1.0,
+ mode="fan_avg",
+ distribution="truncated_normal",
+ seed=seed,
+ dtype=dtype)
+
+
+@tf_export("keras.initializers.lecun_normal", "initializers.lecun_normal")
+def lecun_normal(seed=None):
+ """LeCun normal initializer.
+
+ It draws samples from a truncated normal distribution centered on 0
+ with `stddev = sqrt(1 / fan_in)`
+ where `fan_in` is the number of input units in the weight tensor.
+
+ Arguments:
+ seed: A Python integer. Used to seed the random generator.
+
+ Returns:
+ An initializer.
+
+ References:
+ - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+ - [Efficient
+ Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+ """
+ return VarianceScaling(
+ scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)
+
+
+@tf_export("keras.initializers.lecun_uniform", "initializers.lecun_uniform")
+def lecun_uniform(seed=None):
+ """LeCun uniform initializer.
+
+ It draws samples from a uniform distribution within [-limit, limit]
+ where `limit` is `sqrt(3 / fan_in)`
+ where `fan_in` is the number of input units in the weight tensor.
+
+ Arguments:
+ seed: A Python integer. Used to seed the random generator.
+
+ Returns:
+ An initializer.
+
+ References:
+ LeCun 98, Efficient Backprop,
+ http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
+ """
+ return VarianceScaling(
+ scale=1., mode="fan_in", distribution="uniform", seed=seed)
+
+
+@tf_export("keras.initializers.he_normal", "initializers.he_normal")
+def he_normal(seed=None):
+ """He normal initializer.
+
+ It draws samples from a truncated normal distribution centered on 0
+ with `stddev = sqrt(2 / fan_in)`
+ where `fan_in` is the number of input units in the weight tensor.
+
+ Arguments:
+ seed: A Python integer. Used to seed the random generator.
+
+ Returns:
+ An initializer.
+
+ References:
+ He et al., http://arxiv.org/abs/1502.01852
+ """
+ return VarianceScaling(
+ scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)
+
+
+@tf_export("keras.initializers.he_uniform", "initializers.he_uniform")
+def he_uniform(seed=None):
+ """He uniform variance scaling initializer.
+
+ It draws samples from a uniform distribution within [-limit, limit]
+ where `limit` is `sqrt(6 / fan_in)`
+ where `fan_in` is the number of input units in the weight tensor.
+
+ Arguments:
+ seed: A Python integer. Used to seed the random generator.
+
+ Returns:
+ An initializer.
+
+ References:
+ He et al., http://arxiv.org/abs/1502.01852
+ """
+ return VarianceScaling(
+ scale=2., mode="fan_in", distribution="uniform", seed=seed)
# Utility functions.