diff options
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index e5d16949c3..63b0443634 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -129,8 +129,8 @@ def batch_norm(inputs, Can be used as a normalizer function for conv2d and fully_connected. Args: - inputs: a tensor of size `[batch_size, height, width, channels]` - or `[batch_size, channels]`. + inputs: a tensor with 2 or more dimensions, where the first dimension has + `batch_size`. The normalization is over all but the last dimension. decay: decay for the moving average. center: If True, subtract `beta`. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is @@ -220,7 +220,7 @@ def batch_norm(inputs, is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: - # Calculate the moments based on the individual batch. + # Calculate the moments based on the individual batch. mean, variance = nn.moments(inputs, axis, shift=moving_mean) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: |