diff options
Diffstat (limited to 'tensorflow/python/keras/initializers.py')
-rw-r--r-- | tensorflow/python/keras/initializers.py | 100 |
1 file changed, 10 insertions(+), 90 deletions(-)
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index 28beb6760d..b9d856efa8 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Keras initializer classes (soon to be replaced with core TF initializers).
+"""Keras initializer serialization / deserialization.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -22,107 +22,27 @@ import six
 
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+
+# These imports are brought in so that keras.initializers.deserialize
+# has them available in module_objects.
 from tensorflow.python.ops.init_ops import Constant
 from tensorflow.python.ops.init_ops import glorot_normal_initializer
 from tensorflow.python.ops.init_ops import glorot_uniform_initializer
-
+from tensorflow.python.ops.init_ops import he_normal  # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import he_uniform  # pylint: disable=unused-import
 from tensorflow.python.ops.init_ops import Identity
 from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_normal  # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_uniform  # pylint: disable=unused-import
 from tensorflow.python.ops.init_ops import Ones
 from tensorflow.python.ops.init_ops import Orthogonal
 from tensorflow.python.ops.init_ops import RandomNormal
 from tensorflow.python.ops.init_ops import RandomUniform
 from tensorflow.python.ops.init_ops import TruncatedNormal
-from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import VarianceScaling  # pylint: disable=unused-import
 from tensorflow.python.ops.init_ops import Zeros
-from tensorflow.python.util.tf_export import tf_export
-
-
-@tf_export('keras.initializers.lecun_normal')
-def lecun_normal(seed=None):
-  """LeCun normal initializer.
-
-  It draws samples from a truncated normal distribution centered on 0
-  with `stddev = sqrt(1 / fan_in)`
-  where `fan_in` is the number of input units in the weight tensor.
-
-  Arguments:
-      seed: A Python integer. Used to seed the random generator.
-
-  Returns:
-      An initializer.
-
-  References:
-      - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
-      - [Efficient
-      Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
-  """
-  return VarianceScaling(
-      scale=1., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.lecun_uniform')
-def lecun_uniform(seed=None):
-  """LeCun uniform initializer.
-
-  It draws samples from a uniform distribution within [-limit, limit]
-  where `limit` is `sqrt(3 / fan_in)`
-  where `fan_in` is the number of input units in the weight tensor.
-
-  Arguments:
-      seed: A Python integer. Used to seed the random generator.
-
-  Returns:
-      An initializer.
-  References:
-      LeCun 98, Efficient Backprop,
-      http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
-  """
-  return VarianceScaling(
-      scale=1., mode='fan_in', distribution='uniform', seed=seed)
-
-
-@tf_export('keras.initializers.he_normal')
-def he_normal(seed=None):
-  """He normal initializer.
-
-  It draws samples from a truncated normal distribution centered on 0
-  with `stddev = sqrt(2 / fan_in)`
-  where `fan_in` is the number of input units in the weight tensor.
-
-  Arguments:
-      seed: A Python integer. Used to seed the random generator.
-
-  Returns:
-      An initializer.
-
-  References:
-      He et al., http://arxiv.org/abs/1502.01852
-  """
-  return VarianceScaling(
-      scale=2., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.he_uniform')
-def he_uniform(seed=None):
-  """He uniform variance scaling initializer.
-
-  It draws samples from a uniform distribution within [-limit, limit]
-  where `limit` is `sqrt(6 / fan_in)`
-  where `fan_in` is the number of input units in the weight tensor.
-
-  Arguments:
-      seed: A Python integer. Used to seed the random generator.
-
-  Returns:
-      An initializer.
-
-  References:
-      He et al., http://arxiv.org/abs/1502.01852
-  """
-  return VarianceScaling(
-      scale=2., mode='fan_in', distribution='uniform', seed=seed)
+from tensorflow.python.util.tf_export import tf_export
 
 
 # Compatibility aliases