diff options
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.cc | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 1688674eb7..09ba092f40 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -566,6 +566,27 @@ class FusedBatchNormOp : public OpKernel { bool is_training_; }; +namespace { + +template <typename Device> +void FillZeros(Tensor* t); + +#if GOOGLE_CUDA +template <> +void FillZeros<GPUDevice>(Tensor* t) { + cudaMemset(const_cast<char*>(t->tensor_data().data()), 0, + t->tensor_data().size()); +} +#endif + +template <> +void FillZeros<CPUDevice>(Tensor* t) { + memset(const_cast<char*>(t->tensor_data().data()), 0, + t->tensor_data().size()); +} + +} // namespace + template <typename Device, typename T, typename U> class FusedBatchNormGradOp : public OpKernel { public: @@ -623,14 +644,17 @@ class FusedBatchNormGradOp : public OpKernel { Tensor* offset_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape, &offset_backprop)); - // two placeholders for estimated_mean and estimated_variance, which are + // Two placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. + // They are filled with zeros so as to avoid NaN outputs. Tensor* placeholder_1 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(3, TensorShape({}), &placeholder_1)); + FillZeros<Device>(placeholder_1); Tensor* placeholder_2 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(4, TensorShape({}), &placeholder_2)); + FillZeros<Device>(placeholder_2); if (is_training_) { functor::FusedBatchNormGrad<Device, T, U>()( |