aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_norm_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-22 10:13:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-22 10:39:27 -0800
commit90b9641b7e7da44343644e5b4e221478594a959d (patch)
tree844b5307b6aeaa0da743dce18b576a68ee6f525c /tensorflow/core/kernels/batch_norm_op.cc
parent719476a70fb823a8ac1200e1fc6825c982fcc7d3 (diff)
Fix bug in buffer forwarding for BatchNorm.
Change: 148234275
Diffstat (limited to 'tensorflow/core/kernels/batch_norm_op.cc')
-rw-r--r--tensorflow/core/kernels/batch_norm_op.cc15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
index 7c95d4dd20..77d7b34e69 100644
--- a/tensorflow/core/kernels/batch_norm_op.cc
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -115,25 +115,26 @@ class BatchNormGradOp : public OpKernel {
out_backprop.shape().DebugString()));
Tensor* dx = nullptr;
- if (!context->forward_input_to_output(0, 0, &dx)) {
+ if (!context->forward_input_to_output_with_shape(0, 0, input.shape(),
+ &dx) &&
+ !context->forward_input_to_output_with_shape(4, 0, input.shape(),
+ &dx)) {
OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx));
}
Tensor* dm = nullptr;
- if (!context->forward_input_to_output(1, 1, &dm)) {
+ if (!context->forward_input_to_output_with_shape(1, 1, mean.shape(), &dm)) {
OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm));
}
Tensor* dv = nullptr;
- if (!context->forward_input_to_output(2, 2, &dv)) {
+ if (!context->forward_input_to_output_with_shape(2, 2, var.shape(), &dv)) {
OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv));
}
Tensor* db = nullptr;
- if (!context->forward_input_to_output(3, 3, &db)) {
+ if (!context->forward_input_to_output_with_shape(3, 3, mean.shape(), &db)) {
OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
}
Tensor* dg = nullptr;
- if (!context->forward_input_to_output(4, 4, &dg)) {
- OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
- }
+ OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
// Scratch buffer of [depth] dimension, aka the 4th dimension of input,
// which is dim_size(3), for calculating various combinations of