Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc  |  92
1 file changed, 69 insertions(+), 23 deletions(-)
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index cc303e8dba..92b093eec6 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -17,7 +17,6 @@ limitations under the License.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
@@ -28,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -39,7 +39,8 @@ namespace functor {
// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T>
struct FusedBatchNorm;
-// Functor used by FusedBatchNormGradOp to do the computations.
+// Functor used by FusedBatchNormGradOp to do the computations when
+// is_training=True.
template <typename Device, typename T>
struct FusedBatchNormGrad;
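
Note: FusedBatchNormGrad (the is_training=True path) must account for the dependence of the batch statistics on x, which is why it is a dedicated functor. Below is a minimal per-channel scalar sketch of that gradient, assuming a flat layout and an illustrative helper name; it is not code from this file.

#include <cmath>
#include <cstddef>
#include <vector>

// Training-mode batch norm gradient for one channel, where n is the number
// of reduced elements (batch * height * width). Standard derivation:
// dx = scale * inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat)).
void BatchNormGradTraining(const std::vector<float>& dy,
                           const std::vector<float>& x,
                           float scale, float mean, float variance,
                           float epsilon, std::vector<float>* dx,
                           float* dscale, float* doffset) {
  const float n = static_cast<float>(x.size());
  const float inv_std = 1.0f / std::sqrt(variance + epsilon);
  *dscale = 0.0f;   // accumulates sum(dy * x_hat) over the channel
  *doffset = 0.0f;  // accumulates sum(dy) over the channel
  for (std::size_t i = 0; i < x.size(); ++i) {
    *dscale += dy[i] * (x[i] - mean) * inv_std;
    *doffset += dy[i];
  }
  dx->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float x_hat = (x[i] - mean) * inv_std;
    (*dx)[i] =
        scale * inv_std * (dy[i] - *doffset / n - x_hat * (*dscale) / n);
  }
}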
@@ -352,7 +353,7 @@ template <typename T>
struct FusedBatchNormGrad<GPUDevice, T> {
void operator()(OpKernelContext* context, const Tensor& y_backprop,
const Tensor& x, const Tensor& scale, const Tensor& mean,
- const Tensor& variance, T epsilon, Tensor* x_backprop,
+ const Tensor& inv_variance, T epsilon, Tensor* x_backprop,
Tensor* scale_backprop, Tensor* offset_backprop,
TensorFormat tensor_format) {
auto* stream = context->op_device_context()->stream();
@@ -441,16 +442,18 @@ struct FusedBatchNormGrad<GPUDevice, T> {
auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(mean);
- auto variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(variance);
+ auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_variance);
auto scale_backprop_ptr =
StreamExecutorUtil::AsDeviceMemory<T>(*scale_backprop);
auto offset_backprop_ptr =
StreamExecutorUtil::AsDeviceMemory<T>(*offset_backprop);
+ // The cuDNN kernel computes the inverse variance in the forward pass and
+ // reuses it in the backward pass.
bool cudnn_launch_status =
stream
->ThenBatchNormalizationBackward(
- y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, variance_ptr,
+ y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
x_desc, scale_offset_desc, static_cast<double>(epsilon),
&x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
.ok();
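
Note: the inv_variance passed to cuDNN here is the value its forward pass cached, namely 1 / sqrt(variance + epsilon), which lets the backward kernel skip recomputing the rsqrt. A one-line illustrative helper (the name is an assumption, not this file's API):

#include <cmath>

// What cuDNN caches as "saved inverse variance" in the forward pass.
inline float InverseVariance(float variance, float epsilon) {
  return 1.0f / std::sqrt(variance + epsilon);
}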
@@ -468,6 +471,20 @@ struct FusedBatchNormGrad<GPUDevice, T> {
}
}
};
+
+// Forward declarations of the functor specializations for GPU.
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void FusedBatchNormFreezeGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, const Tensor& y_backprop_input, \
+ const Tensor& x_input, const Tensor& scale_input, \
+ const Tensor& mean_input, const Tensor& variance_input, T epsilon, \
+ Tensor* x_backprop_output, Tensor* scale_backprop_output, \
+ Tensor* offset_backprop_output, typename TTypes<T>::Vec scratch1, \
+ typename TTypes<T>::Vec scratch2); \
+ extern template struct FusedBatchNormFreezeGrad<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+
#endif // GOOGLE_CUDA
} // namespace functor
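
Note: the DECLARE_GPU_SPEC block added above follows the common TensorFlow functor pattern: the specialization's operator() is declared, and the extern template (explicit-instantiation declaration) suppresses implicit instantiation here so this translation unit links against the definition compiled in the corresponding .cu.cc file. A generic sketch of the pattern, with illustrative names:

// Primary template, normally defined in a shared header.
template <typename Device, typename T>
struct SomeFunctor {
  void operator()(const Device& d, T x);
};

// CPU translation unit: declare the GPU specialization but do not
// instantiate it; the definition lives in a .cu.cc file.
template <>
void SomeFunctor<GPUDevice, float>::operator()(const GPUDevice& d, float x);
extern template struct SomeFunctor<GPUDevice, float>;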
@@ -511,7 +528,7 @@ class FusedBatchNormOp : public OpKernel {
if (is_training_) {
OP_REQUIRES(
context, estimated_mean.dim_size(0) == 0,
- errors::InvalidArgument("estimated_mean empty for training",
+ errors::InvalidArgument("estimated_mean must be empty for training",
estimated_mean.shape().DebugString()));
OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
errors::InvalidArgument(
@@ -531,14 +548,14 @@ class FusedBatchNormOp : public OpKernel {
Tensor* saved_mean = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(3, scale.shape(), &saved_mean));
- Tensor* saved_inv_var = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(4, scale.shape(), &saved_inv_var));
+ Tensor* saved_maybe_inv_var = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
+ &saved_maybe_inv_var));
functor::FusedBatchNorm<Device, T>()(
context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
- y, batch_mean, batch_var, saved_mean, saved_inv_var, tensor_format_,
- is_training_);
+ y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+ tensor_format_, is_training_);
}
private:
@@ -559,16 +576,21 @@ class FusedBatchNormGradOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
}
void Compute(OpKernelContext* context) override {
const Tensor& y_backprop = context->input(0);
const Tensor& x = context->input(1);
const Tensor& scale = context->input(2);
- const Tensor& saved_mean = context->input(3);
- // The Eigen implementation saves variance in the forward pass, while cuDNN
+ // When is_training=True, batch mean and variance/inverted variance are
+ // saved in the forward pass to be reused here. When is_training=False,
+ // population mean and variance need to be forwarded here to compute the
+ // gradients.
+ const Tensor& saved_mean_or_pop_mean = context->input(3);
+ // The Eigen implementation saves variance in the forward pass, while cuDNN
// saves inverted variance.
- const Tensor& saved_maybe_inv_var = context->input(4);
+ const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
OP_REQUIRES(context, y_backprop.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
@@ -579,13 +601,14 @@ class FusedBatchNormGradOp : public OpKernel {
OP_REQUIRES(context, scale.dims() == 1,
errors::InvalidArgument("scale must be 1-dimensional",
scale.shape().DebugString()));
- OP_REQUIRES(context, saved_mean.dims() == 1,
- errors::InvalidArgument("saved mean must be 1-dimensional",
- saved_mean.shape().DebugString()));
OP_REQUIRES(
- context, saved_maybe_inv_var.dims() == 1,
- errors::InvalidArgument("saved variance must be 1-dimensional",
- saved_maybe_inv_var.shape().DebugString()));
+ context, saved_mean_or_pop_mean.dims() == 1,
+ errors::InvalidArgument("saved mean must be 1-dimensional",
+ saved_mean_or_pop_mean.shape().DebugString()));
+ OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
+ errors::InvalidArgument(
+ "saved variance must be 1-dimensional",
+ saved_maybe_inv_var_or_pop_var.shape().DebugString()));
Tensor* x_backprop = nullptr;
OP_REQUIRES_OK(context,
@@ -607,14 +630,37 @@ class FusedBatchNormGradOp : public OpKernel {
OP_REQUIRES_OK(
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
- functor::FusedBatchNormGrad<Device, T>()(
- context, y_backprop, x, scale, saved_mean, saved_maybe_inv_var,
- epsilon_, x_backprop, scale_backprop, offset_backprop, tensor_format_);
+ if (is_training_) {
+ functor::FusedBatchNormGrad<Device, T>()(
+ context, y_backprop, x, scale, saved_mean_or_pop_mean,
+ saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
+ offset_backprop, tensor_format_);
+
+ } else {
+ // The necessary layout conversion is currently done in Python.
+ CHECK(tensor_format_ == FORMAT_NHWC)
+ << "The implementation of FusedBatchNormGrad with is_training=False "
+ "only supports the "
+ << "NHWC tensor format for now.";
+ Tensor scratch1, scratch2;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ scale_offset_shape, &scratch1));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ scale_offset_shape, &scratch2));
+ functor::FusedBatchNormFreezeGrad<Device, T>()(
+ context->eigen_device<Device>(), y_backprop, x, scale,
+ saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
+ x_backprop, scale_backprop, offset_backprop, scratch1.vec<T>(),
+ scratch2.vec<T>());
+ }
}
private:
T epsilon_;
TensorFormat tensor_format_;
+ bool is_training_;
};
REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_CPU),
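
Note: with is_training=False the statistics are constants with no dependence on x, so the FusedBatchNormFreezeGrad path reduces to the simpler gradient sketched below (per-channel scalar form; the helper name and flat layout are assumptions, and the two running sums are plausibly what the scratch1/scratch2 vectors hold in the real functor):

#include <cmath>
#include <cstddef>
#include <vector>

// Frozen-statistics gradient for one channel: mean and variance are
// constants, so the training-mode correction terms drop out.
void BatchNormGradFrozen(const std::vector<float>& dy,
                         const std::vector<float>& x, float scale,
                         float pop_mean, float pop_variance, float epsilon,
                         std::vector<float>* dx, float* dscale,
                         float* doffset) {
  const float inv_std = 1.0f / std::sqrt(pop_variance + epsilon);
  *dscale = 0.0f;   // accumulates sum(dy * (x - pop_mean) * inv_std)
  *doffset = 0.0f;  // accumulates sum(dy)
  dx->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*dx)[i] = dy[i] * scale * inv_std;
    *dscale += dy[i] * (x[i] - pop_mean) * inv_std;
    *doffset += dy[i];
  }
}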