aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index dc4ee9226a..aee57dbeaf 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -117,6 +117,7 @@ def batch_norm(inputs,
scale=False,
epsilon=0.001,
activation_fn=None,
+ initializers={},
updates_collections=ops.GraphKeys.UPDATE_OPS,
is_training=True,
reuse=None,
@@ -211,39 +212,43 @@ def batch_norm(inputs,
if center:
beta_collections = utils.get_variable_collections(variables_collections,
'beta')
+ beta_initializer = initializers.get('beta', init_ops.zeros_initializer)
beta = variables.model_variable('beta',
shape=params_shape,
dtype=dtype,
- initializer=init_ops.zeros_initializer,
+ initializer=beta_initializer,
collections=beta_collections,
trainable=trainable)
if scale:
gamma_collections = utils.get_variable_collections(variables_collections,
'gamma')
+ gamma_initializer = initializers.get('gamma', init_ops.ones_initializer)
gamma = variables.model_variable('gamma',
shape=params_shape,
dtype=dtype,
- initializer=init_ops.ones_initializer,
+ initializer=gamma_initializer,
collections=gamma_collections,
trainable=trainable)
# Create moving_mean and moving_variance variables and add them to the
# appropiate collections.
moving_mean_collections = utils.get_variable_collections(
variables_collections, 'moving_mean')
+ moving_mean_initializer = initializers.get('moving_mean', init_ops.zeros_initializer)
moving_mean = variables.model_variable(
'moving_mean',
shape=params_shape,
dtype=dtype,
- initializer=init_ops.zeros_initializer,
+ initializer=moving_mean_initializer,
trainable=False,
collections=moving_mean_collections)
moving_variance_collections = utils.get_variable_collections(
variables_collections, 'moving_variance')
+ moving_variance_initializer = initializers.get('moving_variance', init_ops.ones_initializer)
moving_variance = variables.model_variable(
'moving_variance',
shape=params_shape,
dtype=dtype,
- initializer=init_ops.ones_initializer,
+ initializer=moving_variance_initializer,
trainable=False,
collections=moving_variance_collections)