diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index dc4ee9226a..aee57dbeaf 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -117,6 +117,7 @@ def batch_norm(inputs, scale=False, epsilon=0.001, activation_fn=None, + initializers={}, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, @@ -211,39 +212,43 @@ def batch_norm(inputs, if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') + beta_initializer = initializers.get('beta', init_ops.zeros_initializer) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, - initializer=init_ops.zeros_initializer, + initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') + gamma_initializer = initializers.get('gamma', init_ops.ones_initializer) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, - initializer=init_ops.ones_initializer, + initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropiate collections. moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') + moving_mean_initializer = initializers.get('moving_mean', init_ops.zeros_initializer) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, - initializer=init_ops.zeros_initializer, + initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') + moving_variance_initializer = initializers.get('moving_variance', init_ops.ones_initializer) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, - initializer=init_ops.ones_initializer, + initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) |