Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc | 70
1 file changed, 39 insertions(+), 31 deletions(-)
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 0ecb829f34..1688674eb7 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -54,25 +54,20 @@ struct FusedBatchNorm<CPUDevice, T, U> {
Tensor* batch_var_output, Tensor* saved_mean_output,
Tensor* saved_var_output, TensorFormat tensor_format,
bool is_training) {
- // Currently U is ignored, since we only support the case where T and U are
- // both float32.
- // TODO(reedwm): Add float16 support, use U, and remove these asserts.
- static_assert(std::is_same<T, float>::value, "T currently must be float.");
- static_assert(std::is_same<U, float>::value, "U currently must be float.");
OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
errors::Internal("The CPU implementation of FusedBatchNorm "
"only supports NHWC tensor format for now."));
typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
- typename TTypes<T>::ConstVec scale(scale_input.vec<T>());
- typename TTypes<T>::ConstVec offset(offset_input.vec<T>());
- typename TTypes<T>::ConstVec estimated_mean(estimated_mean_input.vec<T>());
- typename TTypes<T>::ConstVec estimated_variance(
- estimated_variance_input.vec<T>());
+ typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
+ typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
+ typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
+ typename TTypes<U>::ConstVec estimated_variance(
+ estimated_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>());
- typename TTypes<T>::Vec batch_mean(batch_mean_output->vec<T>());
- typename TTypes<T>::Vec batch_var(batch_var_output->vec<T>());
- typename TTypes<T>::Vec saved_mean(saved_mean_output->vec<T>());
- typename TTypes<T>::Vec saved_var(saved_var_output->vec<T>());
+ typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
+ typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
+ typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
+ typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());
const CPUDevice& d = context->eigen_device<CPUDevice>();
@@ -93,15 +88,15 @@ struct FusedBatchNorm<CPUDevice, T, U> {
bcast_spec.set(0, rest_size);
#endif
- auto x_rest_by_depth = x.reshape(rest_by_depth);
+ auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
- T rest_size_inv = static_cast<T>(1.0f / static_cast<T>(rest_size));
+ U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
// This adjustment is for Bessel's correction
- T rest_size_adjust =
- static_cast<T>(rest_size) / static_cast<T>(rest_size_minus_one);
+ U rest_size_adjust =
+ static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
- Eigen::Tensor<T, 1, Eigen::RowMajor> mean(depth);
- Eigen::Tensor<T, 1, Eigen::RowMajor> variance(depth);
+ Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
+ Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
if (is_training) {
mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
batch_mean.device(d) = mean;
@@ -129,7 +124,7 @@ struct FusedBatchNorm<CPUDevice, T, U> {
auto x_shifted =
x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);
- y.reshape(rest_by_depth).device(d) = x_shifted;
+ y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
}
};
@@ -138,7 +133,7 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
const Tensor& x_input, const Tensor& scale_input,
const Tensor& mean_input, const Tensor& variance_input,
- T epsilon, Tensor* x_backprop_output,
+ U epsilon, Tensor* x_backprop_output,
Tensor* scale_backprop_output, Tensor* offset_backprop_output,
TensorFormat tensor_format) {
OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
@@ -147,12 +142,12 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
typename TTypes<T, 4>::ConstTensor y_backprop(
y_backprop_input.tensor<T, 4>());
typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
- typename TTypes<T>::ConstVec scale(scale_input.vec<T>());
- typename TTypes<T>::ConstVec mean(mean_input.vec<T>());
- typename TTypes<T>::ConstVec variance(variance_input.vec<T>());
+ typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
+ typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
+ typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
- typename TTypes<T>::Vec scale_backprop(scale_backprop_output->vec<T>());
- typename TTypes<T>::Vec offset_backprop(offset_backprop_output->vec<T>());
+ typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
+ typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
// Note: the following formulas are used to compute the gradients for
// back propagation.
@@ -181,8 +176,8 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
bcast_spec.set(0, rest_size);
#endif
- auto x_rest_by_depth = x.reshape(rest_by_depth);
- T rest_size_inv = static_cast<T>(1.0f / static_cast<T>(rest_size));
+ auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
+ U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
auto x_mean_rest_by_depth =
mean.reshape(one_by_depth).broadcast(bcast_spec);
@@ -192,7 +187,8 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
auto x_scaled = x_centered * coef0_rest_by_depth;
- auto y_backprop_rest_by_depth = y_backprop.eval().reshape(rest_by_depth);
+ auto y_backprop_rest_by_depth =
+ y_backprop.eval().reshape(rest_by_depth).template cast<U>();
scale_backprop.device(d) =
(y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
@@ -214,7 +210,7 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
.reshape(one_by_depth)
.broadcast(bcast_spec);
x_backprop.reshape(rest_by_depth).device(d) =
- coef1 * (y_backprop_centered - x_centered * coef2);
+ (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
}
};
@@ -689,6 +685,18 @@ REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
.TypeConstraint<float>("U"),
FusedBatchNormGradOp<CPUDevice, float, float>);
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<Eigen::half>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormOp<CPUDevice, Eigen::half, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<Eigen::half>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);
+
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
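
For reference, below is a minimal standalone Eigen sketch of the compute-in-float pattern this change introduces (an illustration only, not part of the patch): half-precision data of type T is cast to the float compute type U before the reductions, Bessel's correction is applied to the batch variance, and only the normalized output is cast back to T. It assumes nothing beyond Eigen's Tensor module; names such as rest_size and depth simply mirror the kernel's variables.

// Illustration only; not TensorFlow code. Mirrors the cast<U>() / cast<T>()
// pattern added in the diff above.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int rest_size = 4;  // N * H * W, flattened
  const int depth = 2;      // C in NHWC

  // Half-precision activations, flattened to (rest_size x depth), mirroring
  // x.reshape(rest_by_depth) in the kernel.
  Eigen::Tensor<Eigen::half, 2, Eigen::RowMajor> x(rest_size, depth);
  x.setConstant(Eigen::half(1.5f));

  // Reduce over the "rest" dimension in float, as the kernel now does via
  // x.reshape(rest_by_depth).template cast<U>().
  Eigen::array<int, 1> reduce_dims{0};
  const float rest_size_inv = 1.0f / static_cast<float>(rest_size);
  Eigen::Tensor<float, 1, Eigen::RowMajor> mean =
      x.cast<float>().sum(reduce_dims) * rest_size_inv;

  // Population variance, then Bessel's correction for the batch variance,
  // with rest_size_adjust = rest_size / (rest_size - 1) as in the kernel.
  Eigen::array<Eigen::Index, 2> one_by_depth{1, depth};
  Eigen::array<Eigen::Index, 2> bcast{rest_size, 1};
  auto x_centered =
      x.cast<float>() - mean.reshape(one_by_depth).broadcast(bcast);
  Eigen::Tensor<float, 1, Eigen::RowMajor> variance =
      x_centered.square().sum(reduce_dims) * rest_size_inv;
  const float rest_size_adjust =
      static_cast<float>(rest_size) / static_cast<float>(rest_size - 1);
  Eigen::Tensor<float, 1, Eigen::RowMajor> batch_var =
      variance * rest_size_adjust;

  // Only the final result is cast back to half, analogous to
  // y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>().
  Eigen::Tensor<Eigen::half, 1, Eigen::RowMajor> mean_half =
      mean.cast<Eigen::half>();

  std::cout << static_cast<float>(mean_half(0)) << " " << variance(0) << " "
            << batch_var(0) << "\n";
  return 0;
}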