aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py
blob: de7a315657936be1984944de63fdaee6ec9a18e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""TensorFlow ops for Batch Normalization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import moving_averages


def batch_normalize(tensor_in,
                    epsilon=1e-5,
                    convnet=False,
                    decay=0.9,
                    scale_after_normalization=True):
  """Batch normalization.

  Creates (or reuses) the variables `gamma`, `beta`, `moving_mean` and
  `moving_var` under the variable scope "batch_norm", each of shape
  `[in_depth]` (the last dimension of `tensor_in`). Which statistics are used
  is selected at run time by the single boolean tensor stored in the
  "IS_TRAINING" graph collection:

  * training: mean/variance are computed from the current batch with
    `nn.moments`, and `moving_mean`/`moving_var` are updated as a side effect
    whenever the returned tensor is evaluated;
  * inference: the stored `moving_mean`/`moving_var` are used unchanged.

  Args:
    tensor_in: input `Tensor`, 4D shape: [batch, in_height, in_width, in_depth].
    epsilon: A float number to avoid being divided by 0.
    convnet: Whether this is for convolutional net use. If `True`, moments
        are computed across axes `[0, 1, 2]`. Otherwise, only across `[0]`.
    decay: Decay rate for the exponential moving averages of mean and
        variance (passed to `moving_averages.assign_moving_average`).
    scale_after_normalization: Whether to multiply the normalized result by
        `gamma`.

  Returns:
    A batch-normalized `Tensor`.

  Raises:
    ValueError: If the "IS_TRAINING" collection does not contain exactly one
        element.
  """
  # NOTE(review): with convnet=False, nn.moments over axis [0] yields
  # statistics of shape shape[1:], while the moving-average variables below
  # have shape [shape[-1]]; for inputs of rank > 2 these likely do not match.
  # TODO confirm whether the non-convnet path is usable.
  shape = tensor_in.get_shape().as_list()

  with vs.variable_scope("batch_norm"):
    # Learned per-channel scale (random normal, mean 1., stddev 0.02) and
    # offset (zeros).
    gamma = vs.get_variable(
        "gamma", [shape[-1]],
        initializer=init_ops.random_normal_initializer(1., 0.02))
    beta = vs.get_variable("beta", [shape[-1]],
                           initializer=init_ops.constant_initializer(0.))
    # Running statistics used at inference; excluded from trainable variables.
    moving_mean = vs.get_variable(
        'moving_mean',
        shape=[shape[-1]],
        initializer=init_ops.zeros_initializer,
        trainable=False)
    moving_var = vs.get_variable(
        'moving_var',
        shape=[shape[-1]],
        initializer=init_ops.ones_initializer,
        trainable=False)

    def _update_mean_var():
      """Returns batch mean/variance and updates the moving averages."""
      axis = [0, 1, 2] if convnet else [0]
      mean, var = nn.moments(tensor_in, axis)
      update_moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      update_moving_var = moving_averages.assign_moving_average(
          moving_var, var, decay)
      # The identity ops are created under the control dependency, so reading
      # the returned mean/var forces both moving-average updates to run.
      with ops.control_dependencies([update_moving_mean, update_moving_var]):
        return array_ops_.identity(mean), array_ops_.identity(var)

    # `cond` needs a scalar boolean predicate. An empty collection or one with
    # several entries would otherwise fail obscurely inside `cond`.
    training_flags = ops.get_collection("IS_TRAINING")
    if len(training_flags) != 1:
      raise ValueError(
          "Expected exactly one element in the 'IS_TRAINING' collection "
          "(a boolean tensor that squeezes to a scalar), found %d." %
          len(training_flags))
    is_training = array_ops_.squeeze(training_flags)
    mean, variance = control_flow_ops.cond(is_training, _update_mean_var,
                                           lambda: (moving_mean, moving_var))
    return nn.batch_norm_with_global_normalization(
        tensor_in,
        mean,
        variance,
        beta,
        gamma,
        epsilon,
        scale_after_normalization=scale_after_normalization)