aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/initializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/initializers.py')
-rw-r--r--tensorflow/python/keras/initializers.py100
1 files changed, 10 insertions, 90 deletions
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index 28beb6760d..b9d856efa8 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras initializer classes (soon to be replaced with core TF initializers).
+"""Keras initializer serialization / deserialization.
"""
from __future__ import absolute_import
from __future__ import division
@@ -22,107 +22,27 @@ import six
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+
+# These imports are brought in so that keras.initializers.deserialize
+# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import glorot_normal_initializer
from tensorflow.python.ops.init_ops import glorot_uniform_initializer
-
+from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal
from tensorflow.python.ops.init_ops import RandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal
-from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
-from tensorflow.python.util.tf_export import tf_export
-
-
-@tf_export('keras.initializers.lecun_normal')
-def lecun_normal(seed=None):
- """LeCun normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(1 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- - [Efficient
- Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
- """
- return VarianceScaling(
- scale=1., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.lecun_uniform')
-def lecun_uniform(seed=None):
- """LeCun uniform initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(3 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
- References:
- LeCun 98, Efficient Backprop,
- http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_in', distribution='uniform', seed=seed)
-
-
-@tf_export('keras.initializers.he_normal')
-def he_normal(seed=None):
- """He normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(2 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- He et al., http://arxiv.org/abs/1502.01852
- """
- return VarianceScaling(
- scale=2., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.he_uniform')
-def he_uniform(seed=None):
- """He uniform variance scaling initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(6 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- He et al., http://arxiv.org/abs/1502.01852
- """
- return VarianceScaling(
- scale=2., mode='fan_in', distribution='uniform', seed=seed)
+from tensorflow.python.util.tf_export import tf_export
# Compatibility aliases