aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fused_batch_norm_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.h')
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.h71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index da8692caad..1566cfa4dc 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -17,9 +17,14 @@ limitations under the License.
#define TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
namespace functor {
+
+#if GOOGLE_CUDA
+
// There is a behavior difference between cuDNN v4 and v5 with regard to the
// scaling factor for function cudnnBatchNormalizationForwardInference.
// This function corrects the scaling factor if cuDNN v4 is used, so that
@@ -43,6 +48,72 @@ struct InvVarianceToVariance {
void operator()(const Eigen::GpuDevice& d, double epsilon, int sample_size,
int channels, T* variance);
};
+
+#endif // GOOGLE_CUDA
+
+// Functor used by FusedBatchNormGradOp to do the computations when
+// is_training=False. Both CPU and GPU will use this functor.
+template <typename Device, typename T>
+struct FusedBatchNormFreezeGrad {
+ void operator()(const Device& d, const Tensor& y_backprop_input,
+ const Tensor& x_input, const Tensor& scale_input,
+ const Tensor& pop_mean_input,
+ const Tensor& pop_variance_input, T epsilon,
+ Tensor* x_backprop_output, Tensor* scale_backprop_output,
+ Tensor* offset_backprop_output,
+ typename TTypes<T>::Vec scratch1,
+ typename TTypes<T>::Vec scratch2) {
+ typename TTypes<T, 4>::ConstTensor y_backprop(
+ y_backprop_input.tensor<T, 4>());
+ typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
+ typename TTypes<T>::ConstVec scale(scale_input.vec<T>());
+ typename TTypes<T>::ConstVec pop_mean(pop_mean_input.vec<T>());
+ typename TTypes<T>::ConstVec pop_var(pop_variance_input.vec<T>());
+ typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
+ typename TTypes<T>::Vec scale_backprop(scale_backprop_output->vec<T>());
+ typename TTypes<T>::Vec offset_backprop(offset_backprop_output->vec<T>());
+
+ const int depth = pop_mean.dimension(0);
+ const int rest_size = input.size() / depth;
+
+ Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
+ Eigen::array<int, 1> reduction_axis{0};
+ Eigen::array<int, 2> rest_by_one({rest_size, 1});
+#else
+ Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
+ one_by_depth.set(1, depth);
+ Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+ Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
+ rest_by_one.set(0, rest_size);
+#endif
+
+ // offset_backprop = sum(y_backprop)
+ // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
+ // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
+ offset_backprop.device(d) =
+ y_backprop.reshape(rest_by_depth).sum(reduction_axis);
+
+ // scratch1 = rsqrt(pop_var + epsilon)
+ scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
+
+ // scratch2 = sum(y_backprop * (x - mean))
+ scratch2.device(d) =
+ (y_backprop.reshape(rest_by_depth) *
+ (input.reshape(rest_by_depth) -
+ pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
+ .sum(reduction_axis);
+
+ x_backprop.reshape(rest_by_depth).device(d) =
+ y_backprop.reshape(rest_by_depth) * ((scratch1 * scale)
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one));
+ scale_backprop.device(d) = scratch2 * scratch1;
+ }
+};
+
} // namespace functor
} // namespace tensorflow