diff options
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.h | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h index da8692caad..1566cfa4dc 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.h +++ b/tensorflow/core/kernels/fused_batch_norm_op.h @@ -17,9 +17,14 @@ limitations under the License. #define TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { namespace functor { + +#if GOOGLE_CUDA + // There is a behavior difference between cuDNN v4 and v5 with regard to the // scaling factor for function cudnnBatchNormalizationForwardInference. // This function corrects the scaling factor if cuDNN v4 is used, so that @@ -43,6 +48,72 @@ struct InvVarianceToVariance { void operator()(const Eigen::GpuDevice& d, double epsilon, int sample_size, int channels, T* variance); }; + +#endif // GOOGLE_CUDA + +// Functor used by FusedBatchNormGradOp to do the computations when +// is_training=False. Both CPU and GPU will use this functor. +template <typename Device, typename T> +struct FusedBatchNormFreezeGrad { + void operator()(const Device& d, const Tensor& y_backprop_input, + const Tensor& x_input, const Tensor& scale_input, + const Tensor& pop_mean_input, + const Tensor& pop_variance_input, T epsilon, + Tensor* x_backprop_output, Tensor* scale_backprop_output, + Tensor* offset_backprop_output, + typename TTypes<T>::Vec scratch1, + typename TTypes<T>::Vec scratch2) { + typename TTypes<T, 4>::ConstTensor y_backprop( + y_backprop_input.tensor<T, 4>()); + typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>()); + typename TTypes<T>::ConstVec scale(scale_input.vec<T>()); + typename TTypes<T>::ConstVec pop_mean(pop_mean_input.vec<T>()); + typename TTypes<T>::ConstVec pop_var(pop_variance_input.vec<T>()); + typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>()); + typename TTypes<T>::Vec scale_backprop(scale_backprop_output->vec<T>()); + typename TTypes<T>::Vec offset_backprop(offset_backprop_output->vec<T>()); + + const int depth = pop_mean.dimension(0); + const int rest_size = input.size() / depth; + + Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth); + Eigen::array<int, 1> reduction_axis{0}; + Eigen::array<int, 2> rest_by_one({rest_size, 1}); +#else + Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList<Eigen::type2index<0> > reduction_axis; + Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one; + rest_by_one.set(0, rest_size); +#endif + + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop.device(d) = + y_backprop.reshape(rest_by_depth).sum(reduction_axis); + + // scratch1 = rsqrt(pop_var + epsilon) + scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt(); + + // scratch2 = sum(y_backprop * (x - mean)) + scratch2.device(d) = + (y_backprop.reshape(rest_by_depth) * + (input.reshape(rest_by_depth) - + pop_mean.reshape(one_by_depth).broadcast(rest_by_one))) + .sum(reduction_axis); + + x_backprop.reshape(rest_by_depth).device(d) = + y_backprop.reshape(rest_by_depth) * ((scratch1 * scale) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one)); + scale_backprop.device(d) = scratch2 * scratch1; + } +}; + } // namespace functor } // namespace tensorflow |