diff options
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.h | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h index 3af104bf95..38b24d7011 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.h +++ b/tensorflow/core/kernels/fused_batch_norm_op.h @@ -92,28 +92,26 @@ struct FusedBatchNormFreezeGrad { // offset_backprop = sum(y_backprop) // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - - auto y_backprop_rest_by_depth = - y_backprop.reshape(rest_by_depth).template cast<U>(); - auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>(); - - offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis); + offset_backprop.device(d) = y_backprop.reshape(rest_by_depth) + .template cast<U>() + .sum(reduction_axis); // scratch1 = rsqrt(pop_var + epsilon) scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt(); // scratch2 = sum(y_backprop * (x - mean)) scratch2.device(d) = - (y_backprop_rest_by_depth * - (input_rest_by_depth - + (y_backprop.reshape(rest_by_depth).template cast<U>() * + (input.reshape(rest_by_depth).template cast<U>() - pop_mean.reshape(one_by_depth).broadcast(rest_by_one))) .sum(reduction_axis); x_backprop.reshape(rest_by_depth).device(d) = - (y_backprop_rest_by_depth * ((scratch1 * scale) - .eval() - .reshape(one_by_depth) - .broadcast(rest_by_one))) + (y_backprop.reshape(rest_by_depth).template cast<U>() * + ((scratch1 * scale) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one))) .template cast<T>(); scale_backprop.device(d) = scratch2 * scratch1; } |