/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_

// Functor definition for BatchNormOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by BatchNormOp to do the computations.
template <typename Device, typename T>
struct BatchNorm {
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec beta,
                  typename TTypes<T>::ConstVec gamma, T variance_epsilon,
                  bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor output) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<int, 2> one_by_depth(1, depth);
    Eigen::DSizes<int, 2> depth_by_one(depth, 1);
#else
    Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
    depth_by_one.set(0, depth);
#endif

    if (scale_after_normalization) {
      output.reshape(rest_by_depth).device(d) =
          (input.reshape(rest_by_depth) -
           mean.reshape(one_by_depth).broadcast(rest_by_one)) *
              ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
                  .eval()
                  .reshape(one_by_depth)
                  .broadcast(rest_by_one) +
          beta.reshape(one_by_depth).broadcast(rest_by_one);
    } else {
      output.reshape(rest_by_depth).device(d) =
          (input.reshape(rest_by_depth) -
           mean.reshape(one_by_depth).broadcast(rest_by_one)) *
              ((var + var.constant(variance_epsilon)).rsqrt())
                  .eval()
                  .reshape(one_by_depth)
                  .broadcast(rest_by_one) +
          beta.reshape(one_by_depth).broadcast(rest_by_one);
    }
  }
};
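
// Editorial sketch, not part of the original TensorFlow sources: a scalar
// reference for the BatchNorm functor above, assuming the 4-D input has been
// flattened to a row-major [rest_size, depth] buffer (the same shape the
// functor reshapes to). It spells out the per-element arithmetic that the
// Eigen expression performs via reshape/broadcast; the function name and
// raw-pointer interface are hypothetical, chosen only for illustration.
template <typename T>
void BatchNormScalarReference(const T* input, const T* mean, const T* var,
                              const T* beta, const T* gamma,
                              T variance_epsilon,
                              bool scale_after_normalization, int rest_size,
                              int depth, T* output) {
  for (int r = 0; r < rest_size; ++r) {
    for (int c = 0; c < depth; ++c) {
      // rsqrt(v + epsilon): the per-channel inverse standard deviation,
      // matching (var + var.constant(variance_epsilon)).rsqrt() above.
      const T inv_stddev =
          T(1) / Eigen::numext::sqrt(var[c] + variance_epsilon);
      const T normalized = (input[r * depth + c] - mean[c]) * inv_stddev;
      // With scale_after_normalization, scale by gamma before the beta
      // shift; otherwise gamma is skipped (the functor's else branch).
      output[r * depth + c] = scale_after_normalization
                                  ? normalized * gamma[c] + beta[c]
                                  : normalized + beta[c];
    }
  }
}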
template <typename Device, typename T>
struct BatchNormGrad {
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  T variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;

    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<Index, 2> one_by_depth(1, depth);
    Eigen::array<Index, 1> reduction_axis;
    reduction_axis[0] = 0;  // Reduces on first dimension.
#else
    Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#endif

    // db = sum_over_rest(out_backprop)
    //
    // dg = sum_over_rest(out_backprop * (x - m)) * rsqrt(v + epsilon)
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-rsqrt(v + epsilon))
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    if (scale_after_normalization) {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // scratch1 = -1/2 * (var + epsilon) ^ (-3/2)
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
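
// Editorial note, not part of the original file: the BatchNormGrad formulas
// above, written out per element for row r and channel c, with
// inv = rsqrt(var[c] + variance_epsilon) and scale_after_normalization true:
//
//   db[c]    = sum_r out_backprop[r][c]
//   dg[c]    = sum_r out_backprop[r][c] * (x[r][c] - mean[c]) * inv
//   dm[c]    = -db[c] * gamma[c] * inv
//   dv[c]    = sum_r out_backprop[r][c] * (x[r][c] - mean[c])
//              * gamma[c] * (-1/2) * inv^3
//   dx[r][c] = out_backprop[r][c] * gamma[c] * inv
//
// scratch1 holds inv (later overwritten with (-1/2) * inv^3, i.e.
// (-1/2) * (v + epsilon)^(-3/2)) and scratch2 holds
// sum_r out_backprop[r][c] * (x[r][c] - mean[c]), so each assignment in the
// functor maps onto one line above.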