diff options
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.h')
-rw-r--r-- | tensorflow/core/kernels/batch_norm_op.h | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h new file mode 100644 index 0000000000..5981e58460 --- /dev/null +++ b/tensorflow/core/kernels/batch_norm_op.h @@ -0,0 +1,133 @@ +#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ +#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ +// Functor definition for BatchNormOp, must be compilable by nvcc. +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Functor used by BatchNormOp to do the computations. +template <typename Device, typename T> +struct BatchNorm { + void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input, + typename TTypes<T>::ConstVec mean, + typename TTypes<T>::ConstVec var, + typename TTypes<T>::ConstVec beta, + typename TTypes<T>::ConstVec gamma, float variance_epsilon, + bool scale_after_normalization, + typename TTypes<T, 4>::Tensor output) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes<int, 2> rest_by_one(rest_size, 1); + Eigen::DSizes<int, 2> one_by_depth(1, depth); + Eigen::DSizes<int, 2> depth_by_one(depth, 1); +#else + Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one; + rest_by_one.set(0, rest_size); + Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one; + depth_by_one.set(0, depth); +#endif + if (scale_after_normalization) { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt() * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } else { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt()) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } + } +}; + +template <typename Device, typename T> +struct BatchNormGrad { + void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input, + typename TTypes<T>::ConstVec mean, + typename TTypes<T>::ConstVec var, + typename TTypes<T>::ConstVec gamma, + typename TTypes<T, 4>::ConstTensor out_backprop, + float variance_epsilon, bool scale_after_normalization, + typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm, + typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db, + typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1, + typename TTypes<T>::Vec scratch2) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + typedef typename TTypes<T>::ConstVec::Index Index; + Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth); + Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1); + Eigen::DSizes<Index, 2> one_by_depth(1, depth); + + // db = out_backprop + // + // dg = out_backprop * ((x - m) * rsqrt(v + epsilon)) + // + // dv = sum_over_rest(out_backprop * gamma * (x - m)) * + // (-1/2) * (v + epsilon) ^ (-3/2) + // + // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon)) + // + // dx = out_backprop * (gamma * rsqrt(v + epsilon)) + Eigen::array<Index, 1> reduction_axis; + reduction_axis[0] = 0; // Reduces on first dimension. + + db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis); + + // scratch1 = rsqrt(v + epsilon) + scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt(); + + // scratch2 = sum_over_rest(out_backprop * (x - m)) + scratch2.device(d) = (out_backprop.reshape(rest_by_depth) * + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one))) + .sum(reduction_axis); + + if (scale_after_normalization) { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one)); + dm.device(d) = -db * (scratch1 * gamma).eval(); + dg.device(d) = scratch2 * scratch1; + } else { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * + scratch1.reshape(one_by_depth).broadcast(rest_by_one); + dm.device(d) = -db * scratch1; + dg.device(d) = dg.constant(static_cast<T>(0.0)); // Gamma is not learned. + } + + // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2) + scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) / + (var + var.constant(variance_epsilon)); + + if (scale_after_normalization) { + dv.device(d) = scratch2 * (scratch1 * gamma).eval(); + } else { + dv.device(d) = scratch2 * scratch1; + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ |