aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fused_batch_norm_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r-- tensorflow/core/kernels/fused_batch_norm_op.h 22
1 files changed, 10 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index 3af104bf95..38b24d7011 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -92,28 +92,26 @@ struct FusedBatchNormFreezeGrad {
// offset_backprop = sum(y_backprop)
// scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
// x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
-
- auto y_backprop_rest_by_depth =
- y_backprop.reshape(rest_by_depth).template cast<U>();
- auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
-
- offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
+ offset_backprop.device(d) = y_backprop.reshape(rest_by_depth)
+ .template cast<U>()
+ .sum(reduction_axis);
// scratch1 = rsqrt(pop_var + epsilon)
scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
// scratch2 = sum(y_backprop * (x - mean))
scratch2.device(d) =
- (y_backprop_rest_by_depth *
- (input_rest_by_depth -
+ (y_backprop.reshape(rest_by_depth).template cast<U>() *
+ (input.reshape(rest_by_depth).template cast<U>() -
pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
.sum(reduction_axis);
x_backprop.reshape(rest_by_depth).device(d) =
- (y_backprop_rest_by_depth * ((scratch1 * scale)
- .eval()
- .reshape(one_by_depth)
- .broadcast(rest_by_one)))
+ (y_backprop.reshape(rest_by_depth).template cast<U>() *
+ ((scratch1 * scale)
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one)))
.template cast<T>();
scale_backprop.device(d) = scratch2 * scratch1;
}