-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc | 26
1 file changed, 25 insertions(+), 1 deletion(-)
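
For context, here is a minimal standalone sketch of the per-device template-specialization pattern this patch introduces for FillZeros. The names FillZeros, CPUDevice, GPUDevice, Tensor, and cudaMemset come from the patch itself; FakeTensor, CpuDevice, and the main() driver below are illustrative stand-ins rather than TensorFlow's real types, so this is only an approximation of the technique, not the actual kernel code.

// Standalone sketch of the per-device specialization pattern used in the
// patch. FakeTensor and CpuDevice are hypothetical stand-ins for
// tensorflow::Tensor and the Eigen CPU device type.
#include <cstring>
#include <iostream>
#include <vector>

struct CpuDevice {};  // device tag, analogous to the CPUDevice alias in the kernel

// Toy buffer standing in for tensorflow::Tensor in this sketch.
struct FakeTensor {
  std::vector<float> data;
};

// Primary template is declared but not defined, mirroring the patch:
// calling it with an unsupported device tag fails to link instead of
// silently compiling.
template <typename Device>
void FillZeros(FakeTensor* t);

// CPU specialization: plain memset over the raw buffer, as in the patch.
template <>
void FillZeros<CpuDevice>(FakeTensor* t) {
  std::memset(t->data.data(), 0, t->data.size() * sizeof(float));
}

int main() {
  FakeTensor t{{1.f, 2.f, 3.f}};
  FillZeros<CpuDevice>(&t);                        // dispatch chosen at compile time
  for (float v : t.data) std::cout << v << " ";    // prints: 0 0 0
  std::cout << "\n";
}

Leaving the primary template undefined is presumably deliberate: a kernel instantiated for a device with no FillZeros specialization produces a link-time error rather than a silently wrong zero-fill.
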
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 1688674eb7..09ba092f40 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -566,6 +566,27 @@ class FusedBatchNormOp : public OpKernel {
bool is_training_;
};
+namespace {
+
+template <typename Device>
+void FillZeros(Tensor* t);
+
+#if GOOGLE_CUDA
+template <>
+void FillZeros<GPUDevice>(Tensor* t) {
+ cudaMemset(const_cast<char*>(t->tensor_data().data()), 0,
+ t->tensor_data().size());
+}
+#endif
+
+template <>
+void FillZeros<CPUDevice>(Tensor* t) {
+ memset(const_cast<char*>(t->tensor_data().data()), 0,
+ t->tensor_data().size());
+}
+
+} // namespace
+
template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
public:
@@ -623,14 +644,17 @@ class FusedBatchNormGradOp : public OpKernel {
Tensor* offset_backprop = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
&offset_backprop));
- // two placeholders for estimated_mean and estimated_variance, which are
+ // Two placeholders for estimated_mean and estimated_variance, which are
// used for inference and thus not needed here for gradient computation.
+ // They are filled with zeros so as to avoid NaN outputs.
Tensor* placeholder_1 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(3, TensorShape({}), &placeholder_1));
+ FillZeros<Device>(placeholder_1);
Tensor* placeholder_2 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
+ FillZeros<Device>(placeholder_2);
if (is_training_) {
functor::FusedBatchNormGrad<Device, T, U>()(