aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fused_batch_norm_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.h22
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index 38b24d7011..3af104bf95 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -92,26 +92,28 @@ struct FusedBatchNormFreezeGrad {
// offset_backprop = sum(y_backprop)
// scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
// x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
- offset_backprop.device(d) = y_backprop.reshape(rest_by_depth)
- .template cast<U>()
- .sum(reduction_axis);
+
+ auto y_backprop_rest_by_depth =
+ y_backprop.reshape(rest_by_depth).template cast<U>();
+ auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
+
+ offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
// scratch1 = rsqrt(pop_var + epsilon)
scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
// scratch2 = sum(y_backprop * (x - mean))
scratch2.device(d) =
- (y_backprop.reshape(rest_by_depth).template cast<U>() *
- (input.reshape(rest_by_depth).template cast<U>() -
+ (y_backprop_rest_by_depth *
+ (input_rest_by_depth -
pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
.sum(reduction_axis);
x_backprop.reshape(rest_by_depth).device(d) =
- (y_backprop.reshape(rest_by_depth).template cast<U>() *
- ((scratch1 * scale)
- .eval()
- .reshape(one_by_depth)
- .broadcast(rest_by_one)))
+ (y_backprop_rest_by_depth * ((scratch1 * scale)
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one)))
.template cast<T>();
scale_backprop.device(d) = scratch2 * scratch1;
}