 tensorflow/contrib/layers/python/layers/layers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index e5d16949c3..63b0443634 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -129,8 +129,8 @@ def batch_norm(inputs,
   Can be used as a normalizer function for conv2d and fully_connected.

   Args:
-    inputs: a tensor of size `[batch_size, height, width, channels]`
-      or `[batch_size, channels]`.
+    inputs: a tensor with 2 or more dimensions, where the first dimension has
+      `batch_size`. The normalization is over all but the last dimension.
     decay: decay for the moving average.
     center: If True, subtract `beta`. If False, `beta` is ignored.
     scale: If True, multiply by `gamma`. If False, `gamma` is
@@ -220,7 +220,7 @@ def batch_norm(inputs,
is_training_value = utils.constant_value(is_training)
need_moments = is_training_value is None or is_training_value
if need_moments:
- # Calculate the moments based on the individual batch.
+ # Calculate the moments based on the individual batch.
mean, variance = nn.moments(inputs, axis, shift=moving_mean)
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
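
The updated docstring says batch_norm accepts any input with 2 or more dimensions and normalizes over all but the last one. Below is a minimal usage sketch (not part of this commit); it assumes the TF 1.x tf.contrib.layers API this file belongs to, and the placeholder shapes and decay value are illustrative only.

import tensorflow as tf

# 4-D input: [batch_size, height, width, channels] -- the case the old
# docstring described.
images = tf.placeholder(tf.float32, [None, 32, 32, 3])
normed_images = tf.contrib.layers.batch_norm(images, decay=0.9, is_training=True)

# 2-D input: [batch_size, channels] -- also covered by the updated wording.
features = tf.placeholder(tf.float32, [None, 128])
normed_features = tf.contrib.layers.batch_norm(features, decay=0.9, is_training=True)

# During training the layer computes batch moments over all but the last
# dimension, as in the second hunk's nn.moments call; for the 4-D case:
axes = list(range(images.shape.ndims - 1))  # [0, 1, 2]
mean, variance = tf.nn.moments(images, axes)

The per-batch mean and variance are then folded into the moving averages controlled by decay, which is what the moving_vars_fn and updates_collections logic shown in the second hunk manages.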