Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.cc | 92
1 file changed, 69 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index cc303e8dba..92b093eec6 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -17,7 +17,6 @@ limitations under the License.
 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/fused_batch_norm_op.h"
 #include "tensorflow/core/kernels/conv_2d.h"
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/util/stream_executor_util.h"
@@ -28,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fused_batch_norm_op.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {
@@ -39,7 +39,8 @@ namespace functor {
 // Functor used by FusedBatchNormOp to do the computations.
 template <typename Device, typename T>
 struct FusedBatchNorm;
-// Functor used by FusedBatchNormGradOp to do the computations.
+// Functor used by FusedBatchNormGradOp to do the computations when
+// is_training=True.
 template <typename Device, typename T>
 struct FusedBatchNormGrad;
 
@@ -352,7 +353,7 @@ template <typename T>
 struct FusedBatchNormGrad<GPUDevice, T> {
   void operator()(OpKernelContext* context, const Tensor& y_backprop,
                   const Tensor& x, const Tensor& scale, const Tensor& mean,
-                  const Tensor& variance, T epsilon, Tensor* x_backprop,
+                  const Tensor& inv_variance, T epsilon, Tensor* x_backprop,
                   Tensor* scale_backprop, Tensor* offset_backprop,
                   TensorFormat tensor_format) {
     auto* stream = context->op_device_context()->stream();
@@ -441,16 +442,18 @@ struct FusedBatchNormGrad<GPUDevice, T> {
     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
     auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(mean);
-    auto variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(variance);
+    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_variance);
     auto scale_backprop_ptr =
         StreamExecutorUtil::AsDeviceMemory<T>(*scale_backprop);
     auto offset_backprop_ptr =
         StreamExecutorUtil::AsDeviceMemory<T>(*offset_backprop);
 
+    // The cuDNN kernel outputs the inverse variance in the forward pass and
+    // reuses it in the backward pass.
     bool cudnn_launch_status =
         stream
             ->ThenBatchNormalizationBackward(
-                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, variance_ptr,
+                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
                 x_desc, scale_offset_desc, static_cast<double>(epsilon),
                 &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
             .ok();
@@ -468,6 +471,20 @@ struct FusedBatchNormGrad<GPUDevice, T> {
     }
   }
 };
+
+// Forward declarations of the functor specializations for GPU.
+#define DECLARE_GPU_SPEC(T)                                                 \
+  template <>                                                               \
+  void FusedBatchNormFreezeGrad<GPUDevice, T>::operator()(                  \
+      const GPUDevice& d, const Tensor& y_backprop_input,                   \
+      const Tensor& x_input, const Tensor& scale_input,                     \
+      const Tensor& mean_input, const Tensor& variance_input, T epsilon,    \
+      Tensor* x_backprop_output, Tensor* scale_backprop_output,             \
+      Tensor* offset_backprop_output, typename TTypes<T>::Vec scratch1,     \
+      typename TTypes<T>::Vec scratch2);                                    \
+  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+
 #endif  // GOOGLE_CUDA
 }  // namespace functor
 
@@ -511,7 +528,7 @@ class FusedBatchNormOp : public OpKernel {
     if (is_training_) {
       OP_REQUIRES(
           context, estimated_mean.dim_size(0) == 0,
-          errors::InvalidArgument("estimated_mean empty for training",
+          errors::InvalidArgument("estimated_mean must be empty for training",
                                   estimated_mean.shape().DebugString()));
       OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
                   errors::InvalidArgument(
@@ -531,14 +548,14 @@ class FusedBatchNormOp : public OpKernel {
     Tensor* saved_mean = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(3, scale.shape(), &saved_mean));
-    Tensor* saved_inv_var = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(4, scale.shape(), &saved_inv_var));
+    Tensor* saved_maybe_inv_var = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
+                                                     &saved_maybe_inv_var));
 
     functor::FusedBatchNorm<Device, T>()(
         context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
-        y, batch_mean, batch_var, saved_mean, saved_inv_var, tensor_format_,
-        is_training_);
+        y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+        tensor_format_, is_training_);
   }
 
  private:
@@ -559,16 +576,21 @@ class FusedBatchNormGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                 errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
   }
 
   void Compute(OpKernelContext* context) override {
     const Tensor& y_backprop = context->input(0);
    const Tensor& x = context->input(1);
     const Tensor& scale = context->input(2);
-    const Tensor& saved_mean = context->input(3);
-    // The Eigen implementation saves variance in the forward pass, while cuDNN
+    // When is_training=True, batch mean and variance/inverted variance are
+    // saved in the forward pass to be reused here. When is_training=False,
+    // population mean and variance need to be forwarded here to compute the
+    // gradients.
+    const Tensor& saved_mean_or_pop_mean = context->input(3);
+    // The Eigen implementation saves variance in the forward pass, while cuDNN
     // saves inverted variance.
-    const Tensor& saved_maybe_inv_var = context->input(4);
+    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
     OP_REQUIRES(context, y_backprop.dims() == 4,
                 errors::InvalidArgument("input must be 4-dimensional",
@@ -579,13 +601,14 @@ class FusedBatchNormGradOp : public OpKernel {
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
                                         scale.shape().DebugString()));
-    OP_REQUIRES(context, saved_mean.dims() == 1,
-                errors::InvalidArgument("saved mean must be 1-dimensional",
-                                        saved_mean.shape().DebugString()));
     OP_REQUIRES(
-        context, saved_maybe_inv_var.dims() == 1,
-        errors::InvalidArgument("saved variance must be 1-dimensional",
-                                saved_maybe_inv_var.shape().DebugString()));
+        context, saved_mean_or_pop_mean.dims() == 1,
+        errors::InvalidArgument("saved mean must be 1-dimensional",
+                                saved_mean_or_pop_mean.shape().DebugString()));
+    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
+                errors::InvalidArgument(
+                    "saved variance must be 1-dimensional",
+                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));
 
     Tensor* x_backprop = nullptr;
     OP_REQUIRES_OK(context,
@@ -607,14 +630,37 @@ class FusedBatchNormGradOp : public OpKernel {
     OP_REQUIRES_OK(
         context, context->allocate_output(4, TensorShape({}), &placeholder_2));
 
-    functor::FusedBatchNormGrad<Device, T>()(
-        context, y_backprop, x, scale, saved_mean, saved_maybe_inv_var,
-        epsilon_, x_backprop, scale_backprop, offset_backprop, tensor_format_);
+    if (is_training_) {
+      functor::FusedBatchNormGrad<Device, T>()(
+          context, y_backprop, x, scale, saved_mean_or_pop_mean,
+          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
+          offset_backprop, tensor_format_);
+
+    } else {
+      // Necessary layout conversion is currently done in Python.
+      CHECK(tensor_format_ == FORMAT_NHWC)
+          << "The implementation of FusedBatchNormGrad with is_training=False "
+             "only supports "
+          << "NHWC tensor format for now.";
+      Tensor scratch1, scratch2;
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DataTypeToEnum<T>::value,
+                                            scale_offset_shape, &scratch1));
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DataTypeToEnum<T>::value,
+                                            scale_offset_shape, &scratch2));
+      functor::FusedBatchNormFreezeGrad<Device, T>()(
+          context->eigen_device<Device>(), y_backprop, x, scale,
+          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
+          x_backprop, scale_backprop, offset_backprop, scratch1.vec<T>(),
+          scratch2.vec<T>());
+    }
   }
 
  private:
   T epsilon_;
   TensorFormat tensor_format_;
+  bool is_training_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_CPU),
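
For readers following the change: when is_training=False the batch-norm statistics are constants, so the gradient loses the backprop-through-statistics terms of the training path and reduces to per-channel scaling and two per-channel reductions. Below is a minimal C++ sketch of the math the new FusedBatchNormFreezeGrad path computes, assuming an NHWC layout. The function name, std::vector storage, and explicit loops are illustrative assumptions, not the actual Eigen-based functor; the two per-channel accumulators stand in for the scratch1/scratch2 buffers the kernel allocates.

// Sketch only: frozen batch-norm gradient with population statistics.
// y = scale * (x - pop_mean) / sqrt(pop_var + eps) + offset, statistics fixed.
#include <cmath>
#include <cstddef>
#include <vector>

void FrozenBatchNormGrad(const std::vector<float>& y_backprop,  // N*H*W*C
                         const std::vector<float>& x,           // N*H*W*C
                         const std::vector<float>& scale,       // C
                         const std::vector<float>& pop_mean,    // C
                         const std::vector<float>& pop_var,     // C
                         float epsilon, std::size_t channels,
                         std::vector<float>* x_backprop,        // N*H*W*C
                         std::vector<float>* scale_backprop,    // C
                         std::vector<float>* offset_backprop) { // C
  const std::size_t rows = x.size() / channels;  // N*H*W positions per channel
  x_backprop->resize(x.size());
  scale_backprop->assign(channels, 0.0f);   // plays the role of one scratch buffer
  offset_backprop->assign(channels, 0.0f);  // and the other
  for (std::size_t c = 0; c < channels; ++c) {
    // With frozen statistics, rsqrt(var + eps) is a per-channel constant.
    const float inv_std = 1.0f / std::sqrt(pop_var[c] + epsilon);
    for (std::size_t r = 0; r < rows; ++r) {
      const std::size_t i = r * channels + c;
      // dL/dx: a pure rescaling of dy; no mean/variance gradient terms.
      (*x_backprop)[i] = y_backprop[i] * scale[c] * inv_std;
      // dL/dscale: reduce dy * normalized(x) over all positions of channel c.
      (*scale_backprop)[c] += y_backprop[i] * (x[i] - pop_mean[c]) * inv_std;
      // dL/doffset: reduce dy over all positions of channel c.
      (*offset_backprop)[c] += y_backprop[i];
    }
  }
}

This contrasts with the is_training=True path, where y_backprop also flows back through the batch statistics themselves; that is why the training path must reuse the batch mean and (inverse) variance saved by the forward pass, while the frozen path only needs the forwarded population statistics.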