aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_norm_op.h
blob: f1c6e47d14b57496bbd6510c5ba1192a43397b1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
// Functor definition for BatchNormOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by BatchNormOp to do the computations.
template <typename Device, typename T>
struct BatchNorm {
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec beta,
                  typename TTypes<T>::ConstVec gamma, float variance_epsilon,
                  bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor output) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<int, 2> one_by_depth(1, depth);
    Eigen::DSizes<int, 2> depth_by_one(depth, 1);
#else
    Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
    depth_by_one.set(0, depth);
#endif
    if (scale_after_normalization) {
      output.reshape(rest_by_depth).device(d) =
          (input.reshape(rest_by_depth) -
           mean.reshape(one_by_depth).broadcast(rest_by_one)) *
              ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
                  .eval()
                  .reshape(one_by_depth)
                  .broadcast(rest_by_one) +
          beta.reshape(one_by_depth).broadcast(rest_by_one);
    } else {
      output.reshape(rest_by_depth).device(d) =
          (input.reshape(rest_by_depth) -
           mean.reshape(one_by_depth).broadcast(rest_by_one)) *
              ((var + var.constant(variance_epsilon)).rsqrt())
                  .eval()
                  .reshape(one_by_depth)
                  .broadcast(rest_by_one) +
          beta.reshape(one_by_depth).broadcast(rest_by_one);
    }
  }
};

template <typename Device, typename T>
struct BatchNormGrad {
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  float variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;
    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
    Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<Index, 2> one_by_depth(1, depth);

    // db = out_backprop
    //
    // dg = out_backprop * ((x - m) * rsqrt(v + epsilon))
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon))
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    Eigen::array<Index, 1> reduction_axis;
    reduction_axis[0] = 0;  // Reduces on first dimension.

    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    if (scale_after_normalization) {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_