diff options
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.h | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h index 38b24d7011..3af104bf95 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.h +++ b/tensorflow/core/kernels/fused_batch_norm_op.h @@ -92,26 +92,28 @@ struct FusedBatchNormFreezeGrad { // offset_backprop = sum(y_backprop) // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - offset_backprop.device(d) = y_backprop.reshape(rest_by_depth) - .template cast<U>() - .sum(reduction_axis); + + auto y_backprop_rest_by_depth = + y_backprop.reshape(rest_by_depth).template cast<U>(); + auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>(); + + offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis); // scratch1 = rsqrt(pop_var + epsilon) scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt(); // scratch2 = sum(y_backprop * (x - mean)) scratch2.device(d) = - (y_backprop.reshape(rest_by_depth).template cast<U>() * - (input.reshape(rest_by_depth).template cast<U>() - + (y_backprop_rest_by_depth * + (input_rest_by_depth - pop_mean.reshape(one_by_depth).broadcast(rest_by_one))) .sum(reduction_axis); x_backprop.reshape(rest_by_depth).device(d) = - (y_backprop.reshape(rest_by_depth).template cast<U>() * - ((scratch1 * scale) - .eval() - .reshape(one_by_depth) - .broadcast(rest_by_one))) + (y_backprop_rest_by_depth * ((scratch1 * scale) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one))) .template cast<T>(); scale_backprop.device(d) = scratch2 * scratch1; } |