Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/batch_norm_op.cc | 223
1 file changed, 223 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
new file mode 100644
index 0000000000..c67c921631
--- /dev/null
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -0,0 +1,223 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/batch_norm_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class BatchNormOp : public OpKernel {
+ public:
+  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("variance_epsilon", &variance_epsilon_));
+    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
+                                             &scale_after_normalization_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    const Tensor& mean = context->input(1);
+    const Tensor& var = context->input(2);
+    const Tensor& beta = context->input(3);
+    const Tensor& gamma = context->input(4);
+
+    OP_REQUIRES(context, input.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional",
+                                        input.shape().ShortDebugString()));
+    OP_REQUIRES(context, mean.dims() == 1,
+                errors::InvalidArgument("mean must be 1-dimensional",
+                                        mean.shape().ShortDebugString()));
+    OP_REQUIRES(context, var.dims() == 1,
+                errors::InvalidArgument("var must be 1-dimensional",
+                                        var.shape().ShortDebugString()));
+    OP_REQUIRES(context, beta.dims() == 1,
+                errors::InvalidArgument("beta must be 1-dimensional",
+                                        beta.shape().ShortDebugString()));
+    OP_REQUIRES(context, gamma.dims() == 1,
+                errors::InvalidArgument("gamma must be 1-dimensional",
+                                        gamma.shape().ShortDebugString()));
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input.shape(), &output));
+
+    functor::BatchNorm<Device, T>()(
+        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
+        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
+        scale_after_normalization_, output->tensor<T, 4>());
+  }
+
+ private:
+  float variance_epsilon_;
+  bool scale_after_normalization_;
+};
+
+template <typename Device, typename T>
+class BatchNormGradOp : public OpKernel {
+ public:
+  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("variance_epsilon", &variance_epsilon_));
+    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
+                                             &scale_after_normalization_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    const Tensor& mean = context->input(1);
+    const Tensor& var = context->input(2);
+    const Tensor& gamma = context->input(3);
+    const Tensor& out_backprop = context->input(4);
+
+    OP_REQUIRES(context, input.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional",
+                                        input.shape().ShortDebugString()));
+    OP_REQUIRES(context, mean.dims() == 1,
+                errors::InvalidArgument("mean must be 1-dimensional",
+                                        mean.shape().ShortDebugString()));
+    OP_REQUIRES(context, var.dims() == 1,
+                errors::InvalidArgument("var must be 1-dimensional",
+                                        var.shape().ShortDebugString()));
+    OP_REQUIRES(context, gamma.dims() == 1,
+                errors::InvalidArgument("gamma must be 1-dimensional",
+                                        gamma.shape().ShortDebugString()));
+    OP_REQUIRES(
+        context, out_backprop.dims() == 4,
+        errors::InvalidArgument("out_backprop must be 4-dimensional",
+                                out_backprop.shape().ShortDebugString()));
+
+    Tensor* dx = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx));
+    Tensor* dm = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm));
+    Tensor* dv = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv));
+    Tensor* db = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
+    Tensor* dg = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
+
+    // Scratch buffer of [depth] dimension, aka the 4th dimension of input,
+    // which is dim_size(3), for calculating various combinations of
+    // (var + epsilon).
+    Tensor scratch1;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                TensorShape({input.dim_size(3)}), &scratch1));
+
+    // Scratch buffer of [depth] dimension for saving intermediate calculation
+    // values.
+    Tensor scratch2;
+    OP_REQUIRES_OK(context, context->allocate_temp(
+                                DataTypeToEnum<T>::value,
+                                TensorShape({input.dim_size(3)}), &scratch2));
+
+    functor::BatchNormGrad<Device, T>()(
+        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
+        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
+        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
+        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
+        scratch1.vec<T>(), scratch2.vec<T>());
+  }
+
+ private:
+  float variance_epsilon_;
+  bool scale_after_normalization_;
+};
+
+#define REGISTER_KERNEL(T)                                          \
+  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T"),              \
+                          BatchNormOp<CPUDevice, T>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                  \
+  template <>                                                                \
+  void BatchNorm<GPUDevice, T>::operator()(                                  \
+      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,          \
+      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,   \
+      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \
+      float variance_epsilon, bool scale_after_normalization,                \
+      typename TTypes<T, 4>::Tensor output);                                 \
+  extern template struct BatchNorm<GPUDevice, T>;
+
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
+
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T)                                      \
+  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
+                              .Device(DEVICE_GPU)                   \
+                              .TypeConstraint<T>("T"),              \
+                          BatchNormOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+#undef REGISTER_GPU_KERNEL
+
+#endif  // GOOGLE_CUDA
+
+#define REGISTER_KERNEL(T)                                              \
+  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")  \
+                              .Device(DEVICE_CPU)                       \
+                              .TypeConstraint<T>("T"),                  \
+                          BatchNormGradOp<CPUDevice, T>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                    \
+  template <>                                                                  \
+  void BatchNormGrad<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,            \
+      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,     \
+      typename TTypes<T>::ConstVec gamma,                                      \
+      typename TTypes<T, 4>::ConstTensor out_backprop, float variance_epsilon, \
+      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,        \
+      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,                  \
+      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,                  \
+      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2);     \
+  extern template struct BatchNormGrad<GPUDevice, T>;
+
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
+
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T)                                          \
+  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")  \
+                              .Device(DEVICE_GPU)                       \
+                              .TypeConstraint<T>("T"),                  \
+                          BatchNormGradOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+#undef REGISTER_GPU_KERNEL
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
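Note for reviewers: BatchNormOp only validates shapes and forwards to functor::BatchNorm, whose definition lives in the included batch_norm_op.h rather than in this diff. As a rough guide, the normalization that functor applies to each element can be sketched as the standalone scalar function below; the function name and scalar formulation are illustrative only and are not code from this commit.

#include <cmath>

// Sketch: normalize one element x of a channel using that channel's
// statistics (mean, var), the shift beta, the scale gamma, and the
// variance_epsilon attribute eps.
float BatchNormElement(float x, float mean, float var, float beta,
                       float gamma, float eps,
                       bool scale_after_normalization) {
  const float inv_std = 1.0f / std::sqrt(var + eps);
  const float normalized = (x - mean) * inv_std;
  // With scale_after_normalization, the normalized value is also
  // multiplied by gamma before the shift.
  return scale_after_normalization ? normalized * gamma + beta
                                   : normalized + beta;
}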
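Likewise, BatchNormGradOp delegates the actual math to functor::BatchNormGrad. Assuming scale_after_normalization is true and treating one channel's N*H*W elements as a flat array, the gradients it is expected to produce can be sketched as follows; the function and its layout are hypothetical, derived from differentiating the forward formula above with mean and var treated as independent inputs.

#include <cmath>
#include <cstddef>

// Sketch: per-channel gradients for y = gamma * (x - mean) * inv_std + beta.
// x and dy hold the n = N*H*W elements of one channel; dm, dv, db, dg are
// the scalar gradients for that channel's mean, var, beta, and gamma.
void BatchNormGradChannel(const float* x, const float* dy, std::size_t n,
                          float mean, float var, float gamma, float eps,
                          float* dx, float* dm, float* dv, float* db,
                          float* dg) {
  const float inv_std = 1.0f / std::sqrt(var + eps);
  *dm = *dv = *db = *dg = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = dy[i] * gamma * inv_std;         // d(loss)/d(input)
    *db += dy[i];                            // d(loss)/d(beta)
    *dg += dy[i] * (x[i] - mean) * inv_std;  // d(loss)/d(gamma)
    *dv += dy[i] * (x[i] - mean);            // accumulator for d(loss)/d(var)
  }
  *dm = -(*db) * gamma * inv_std;                       // d(loss)/d(mean)
  *dv *= -0.5f * gamma * inv_std * inv_std * inv_std;   // * (var+eps)^(-3/2)
}

The two [depth]-sized scratch tensors allocated in Compute presumably hold per-channel quantities such as (var + epsilon) powers so they are computed once per channel rather than once per element.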