// See docs in ../ops/nn_ops.cc. #define EIGEN_USE_THREADS #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/batch_norm_op.h" #include "tensorflow/core/public/tensor.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template class BatchNormOp : public OpKernel { public: explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("variance_epsilon", &variance_epsilon_)); OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", &scale_after_normalization_)); } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const Tensor& mean = context->input(1); const Tensor& var = context->input(2); const Tensor& beta = context->input(3); const Tensor& gamma = context->input(4); OP_REQUIRES(context, input.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", input.shape().ShortDebugString())); OP_REQUIRES(context, mean.dims() == 1, errors::InvalidArgument("mean must be 1-dimensional", mean.shape().ShortDebugString())); OP_REQUIRES(context, var.dims() == 1, errors::InvalidArgument("var must be 1-dimensional", var.shape().ShortDebugString())); OP_REQUIRES(context, beta.dims() == 1, errors::InvalidArgument("beta must be 1-dimensional", beta.shape().ShortDebugString())); OP_REQUIRES(context, gamma.dims() == 1, errors::InvalidArgument("gamma must be 1-dimensional", gamma.shape().ShortDebugString())); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); functor::BatchNorm()( context->eigen_device(), input.tensor(), mean.vec(), var.vec(), beta.vec(), gamma.vec(), variance_epsilon_, scale_after_normalization_, output->tensor()); } private: float variance_epsilon_; bool scale_after_normalization_; }; template class BatchNormGradOp : public OpKernel { public: explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("variance_epsilon", &variance_epsilon_)); OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", &scale_after_normalization_)); } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const Tensor& mean = context->input(1); const Tensor& var = context->input(2); const Tensor& gamma = context->input(3); const Tensor& out_backprop = context->input(4); OP_REQUIRES(context, input.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", input.shape().ShortDebugString())); OP_REQUIRES(context, mean.dims() == 1, errors::InvalidArgument("mean must be 1-dimensional", mean.shape().ShortDebugString())); OP_REQUIRES(context, var.dims() == 1, errors::InvalidArgument("var must be 1-dimensional", var.shape().ShortDebugString())); OP_REQUIRES(context, gamma.dims() == 1, errors::InvalidArgument("gamma must be 1-dimensional", gamma.shape().ShortDebugString())); OP_REQUIRES( context, out_backprop.dims() == 4, errors::InvalidArgument("out_backprop must be 4-dimensional", out_backprop.shape().ShortDebugString())); Tensor* dx = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx)); Tensor* dm = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm)); Tensor* dv = nullptr; OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv)); Tensor* db = nullptr; OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db)); Tensor* dg = nullptr; OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); // Scratch buffer of [depth] dimension, aka the 4th dimension of input, // which is dim_size(3), for calculating various combinations of // (var + epsilon). Tensor scratch1; OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, TensorShape({input.dim_size(3)}), &scratch1)); // Scratch buffer of [depth] dimension for saving intermediate calculation // values. Tensor scratch2; OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, TensorShape({input.dim_size(3)}), &scratch2)); functor::BatchNormGrad()( context->eigen_device(), input.tensor(), mean.vec(), var.vec(), gamma.vec(), out_backprop.tensor(), variance_epsilon_, scale_after_normalization_, dx->tensor(), dm->vec(), dv->vec(), db->vec(), dg->vec(), scratch1.vec(), scratch2.vec()); } private: float variance_epsilon_; bool scale_after_normalization_; }; #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ BatchNormOp); REGISTER_KERNEL(float); REGISTER_KERNEL(double); #undef REGISTER_KERNEL #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void BatchNorm::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor input, \ typename TTypes::ConstVec mean, typename TTypes::ConstVec var, \ typename TTypes::ConstVec beta, typename TTypes::ConstVec gamma, \ float variance_epsilon, bool scale_after_normalization, \ typename TTypes::Tensor output); \ extern template struct BatchNorm; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); DECLARE_GPU_SPECS(float); #undef DECLARE_GPU_SPEC } // namespace functor // Registration of the GPU implementations. #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ .Device(DEVICE_GPU) \ .TypeConstraint("T"), \ BatchNormOp); REGISTER_GPU_KERNEL(float); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ BatchNormGradOp); REGISTER_KERNEL(float); REGISTER_KERNEL(double); #undef REGISTER_KERNEL #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void BatchNormGrad::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor input, \ typename TTypes::ConstVec mean, typename TTypes::ConstVec var, \ typename TTypes::ConstVec gamma, \ typename TTypes::ConstTensor out_backprop, float variance_epsilon, \ bool scale_after_normalization, typename TTypes::Tensor dx, \ typename TTypes::Vec dm, typename TTypes::Vec dv, \ typename TTypes::Vec db, typename TTypes::Vec dg, \ typename TTypes::Vec scratch1, typename TTypes::Vec scratch2); \ extern template struct BatchNormGrad; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); DECLARE_GPU_SPECS(float); #undef DECLARE_GPU_SPEC } // namespace functor // Registration of the GPU implementations. #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ .Device(DEVICE_GPU) \ .TypeConstraint("T"), \ BatchNormGradOp); REGISTER_GPU_KERNEL(float); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA } // namespace tensorflow