diff options
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/fused_batch_norm_op.cc | 70 |
1 files changed, 39 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 0ecb829f34..1688674eb7 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -54,25 +54,20 @@ struct FusedBatchNorm<CPUDevice, T, U> { Tensor* batch_var_output, Tensor* saved_mean_output, Tensor* saved_var_output, TensorFormat tensor_format, bool is_training) { - // Currently U is ignored, since we only support the case where T and U are - // both float32. - // TODO(reedwm): Add float16 support, use U, and remove these asserts. - static_assert(std::is_same<T, float>::value, "T currently must be float."); - static_assert(std::is_same<U, float>::value, "U currently must be float."); OP_REQUIRES(context, tensor_format == FORMAT_NHWC, errors::Internal("The CPU implementation of FusedBatchNorm " "only supports NHWC tensor format for now.")); typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>()); - typename TTypes<T>::ConstVec scale(scale_input.vec<T>()); - typename TTypes<T>::ConstVec offset(offset_input.vec<T>()); - typename TTypes<T>::ConstVec estimated_mean(estimated_mean_input.vec<T>()); - typename TTypes<T>::ConstVec estimated_variance( - estimated_variance_input.vec<T>()); + typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); + typename TTypes<U>::ConstVec offset(offset_input.vec<U>()); + typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>()); + typename TTypes<U>::ConstVec estimated_variance( + estimated_variance_input.vec<U>()); typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>()); - typename TTypes<T>::Vec batch_mean(batch_mean_output->vec<T>()); - typename TTypes<T>::Vec batch_var(batch_var_output->vec<T>()); - typename TTypes<T>::Vec saved_mean(saved_mean_output->vec<T>()); - typename TTypes<T>::Vec saved_var(saved_var_output->vec<T>()); + typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>()); + typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>()); + typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>()); + typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>()); const CPUDevice& d = context->eigen_device<CPUDevice>(); @@ -93,15 +88,15 @@ struct FusedBatchNorm<CPUDevice, T, U> { bcast_spec.set(0, rest_size); #endif - auto x_rest_by_depth = x.reshape(rest_by_depth); + auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>(); const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1; - T rest_size_inv = static_cast<T>(1.0f / static_cast<T>(rest_size)); + U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size)); // This adjustment is for Bessel's correction - T rest_size_adjust = - static_cast<T>(rest_size) / static_cast<T>(rest_size_minus_one); + U rest_size_adjust = + static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one); - Eigen::Tensor<T, 1, Eigen::RowMajor> mean(depth); - Eigen::Tensor<T, 1, Eigen::RowMajor> variance(depth); + Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth); + Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth); if (is_training) { mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv); batch_mean.device(d) = mean; @@ -129,7 +124,7 @@ struct FusedBatchNorm<CPUDevice, T, U> { auto x_shifted = x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec); - y.reshape(rest_by_depth).device(d) = x_shifted; + y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>(); } }; @@ -138,7 +133,7 @@ struct FusedBatchNormGrad<CPUDevice, T, U> { void operator()(OpKernelContext* context, const Tensor& y_backprop_input, const Tensor& x_input, const Tensor& scale_input, const Tensor& mean_input, const Tensor& variance_input, - T epsilon, Tensor* x_backprop_output, + U epsilon, Tensor* x_backprop_output, Tensor* scale_backprop_output, Tensor* offset_backprop_output, TensorFormat tensor_format) { OP_REQUIRES(context, tensor_format == FORMAT_NHWC, @@ -147,12 +142,12 @@ struct FusedBatchNormGrad<CPUDevice, T, U> { typename TTypes<T, 4>::ConstTensor y_backprop( y_backprop_input.tensor<T, 4>()); typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>()); - typename TTypes<T>::ConstVec scale(scale_input.vec<T>()); - typename TTypes<T>::ConstVec mean(mean_input.vec<T>()); - typename TTypes<T>::ConstVec variance(variance_input.vec<T>()); + typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); + typename TTypes<U>::ConstVec mean(mean_input.vec<U>()); + typename TTypes<U>::ConstVec variance(variance_input.vec<U>()); typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>()); - typename TTypes<T>::Vec scale_backprop(scale_backprop_output->vec<T>()); - typename TTypes<T>::Vec offset_backprop(offset_backprop_output->vec<T>()); + typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>()); + typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>()); // Note: the following formulas are used to compute the gradients for // back propagation. @@ -181,8 +176,8 @@ struct FusedBatchNormGrad<CPUDevice, T, U> { bcast_spec.set(0, rest_size); #endif - auto x_rest_by_depth = x.reshape(rest_by_depth); - T rest_size_inv = static_cast<T>(1.0f / static_cast<T>(rest_size)); + auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>(); + U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size)); auto x_mean_rest_by_depth = mean.reshape(one_by_depth).broadcast(bcast_spec); @@ -192,7 +187,8 @@ struct FusedBatchNormGrad<CPUDevice, T, U> { coef0.eval().reshape(one_by_depth).broadcast(bcast_spec); auto x_scaled = x_centered * coef0_rest_by_depth; - auto y_backprop_rest_by_depth = y_backprop.eval().reshape(rest_by_depth); + auto y_backprop_rest_by_depth = + y_backprop.eval().reshape(rest_by_depth).template cast<U>(); scale_backprop.device(d) = (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims); auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims); @@ -214,7 +210,7 @@ struct FusedBatchNormGrad<CPUDevice, T, U> { .reshape(one_by_depth) .broadcast(bcast_spec); x_backprop.reshape(rest_by_depth).device(d) = - coef1 * (y_backprop_centered - x_centered * coef2); + (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>(); } }; @@ -689,6 +685,18 @@ REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2") .TypeConstraint<float>("U"), FusedBatchNormGradOp<CPUDevice, float, float>); +REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2") + .Device(DEVICE_CPU) + .TypeConstraint<Eigen::half>("T") + .TypeConstraint<float>("U"), + FusedBatchNormOp<CPUDevice, Eigen::half, float>); + +REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2") + .Device(DEVICE_CPU) + .TypeConstraint<Eigen::half>("T") + .TypeConstraint<float>("U"), + FusedBatchNormGradOp<CPUDevice, Eigen::half, float>); + #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( |