diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-22 10:13:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-22 10:39:27 -0800 |
commit | 90b9641b7e7da44343644e5b4e221478594a959d (patch) | |
tree | 844b5307b6aeaa0da743dce18b576a68ee6f525c /tensorflow/core/kernels/batch_norm_op.cc | |
parent | 719476a70fb823a8ac1200e1fc6825c982fcc7d3 (diff) |
Fix bug in buffer forwarding for BatchNorm.
Change: 148234275
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/batch_norm_op.cc | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index 7c95d4dd20..77d7b34e69 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -115,25 +115,26 @@ class BatchNormGradOp : public OpKernel { out_backprop.shape().DebugString())); Tensor* dx = nullptr; - if (!context->forward_input_to_output(0, 0, &dx)) { + if (!context->forward_input_to_output_with_shape(0, 0, input.shape(), + &dx) && + !context->forward_input_to_output_with_shape(4, 0, input.shape(), + &dx)) { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx)); } Tensor* dm = nullptr; - if (!context->forward_input_to_output(1, 1, &dm)) { + if (!context->forward_input_to_output_with_shape(1, 1, mean.shape(), &dm)) { OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm)); } Tensor* dv = nullptr; - if (!context->forward_input_to_output(2, 2, &dv)) { + if (!context->forward_input_to_output_with_shape(2, 2, var.shape(), &dv)) { OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv)); } Tensor* db = nullptr; - if (!context->forward_input_to_output(3, 3, &db)) { + if (!context->forward_input_to_output_with_shape(3, 3, mean.shape(), &db)) { OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db)); } Tensor* dg = nullptr; - if (!context->forward_input_to_output(4, 4, &dg)) { - OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); - } + OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); // Scratch buffer of [depth] dimension, aka the 4th dimension of input, // which is dim_size(3), for calculating various combinations of |