author     Reed Wanderman-Milne <reedwm@google.com>         2017-09-27 12:58:14 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-09-27 13:04:57 -0700
commit     759690f026a1a08b3ac5cc84d8498c05c32b2a7d
tree       9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1  /tensorflow/core/kernels/fused_batch_norm_op.cc
parent     20370104cd8adf4c3f9068dfe95bde54cccadfa5
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16.
PiperOrigin-RevId: 170239448
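For context, here is a minimal usage sketch of what this change enables. It is not part of the commit; it assumes a CUDA-enabled TF 1.x build with a GPU, and that the Python wrapper dispatches float16 inputs to the new FusedBatchNormV2 kernel registered below.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(8, 4, 4, 16), dtype=tf.float16)  # NHWC float16 input
scale = tf.ones([16], dtype=tf.float32)   # per-channel parameters must stay float32
offset = tf.zeros([16], dtype=tf.float32)

# fused_batch_norm returns the normalized output plus the batch statistics.
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, epsilon=0.001, data_format='NHWC', is_training=True)

with tf.Session() as sess:
    y_val, mean_val, _ = sess.run([y, batch_mean, batch_var])
    print(y_val.dtype, mean_val.dtype)  # expect: float16 float32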
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc  179
1 file changed, 114 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 92b093eec6..0ecb829f34 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -37,23 +37,28 @@ using GPUDevice = Eigen::GpuDevice;
 namespace functor {
 
 // Functor used by FusedBatchNormOp to do the computations.
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
 struct FusedBatchNorm;
 // Functor used by FusedBatchNormGradOp to do the computations when
 // is_training=True.
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
 struct FusedBatchNormGrad;
 
-template <typename T>
-struct FusedBatchNorm<CPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNorm<CPUDevice, T, U> {
   void operator()(OpKernelContext* context, const Tensor& x_input,
                   const Tensor& scale_input, const Tensor& offset_input,
                   const Tensor& estimated_mean_input,
-                  const Tensor& estimated_variance_input, T epsilon,
+                  const Tensor& estimated_variance_input, U epsilon,
                   Tensor* y_output, Tensor* batch_mean_output,
                   Tensor* batch_var_output, Tensor* saved_mean_output,
                   Tensor* saved_var_output, TensorFormat tensor_format,
                   bool is_training) {
+    // Currently U is ignored, since we only support the case where T and U are
+    // both float32.
+    // TODO(reedwm): Add float16 support, use U, and remove these asserts.
+    static_assert(std::is_same<T, float>::value, "T currently must be float.");
+    static_assert(std::is_same<U, float>::value, "U currently must be float.");
     OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                 errors::Internal("The CPU implementation of FusedBatchNorm "
                                  "only supports NHWC tensor format for now."));
@@ -128,8 +133,8 @@ struct FusedBatchNorm<CPUDevice, T> {
   }
 };
 
-template <typename T>
-struct FusedBatchNormGrad<CPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNormGrad<CPUDevice, T, U> {
   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                   const Tensor& x_input, const Tensor& scale_input,
                   const Tensor& mean_input, const Tensor& variance_input,
@@ -214,12 +219,12 @@ struct FusedBatchNormGrad<CPUDevice, T> {
 };
 
 #if GOOGLE_CUDA
-template <typename T>
-struct FusedBatchNorm<GPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNorm<GPUDevice, T, U> {
   void operator()(OpKernelContext* context, const Tensor& x,
                   const Tensor& scale, const Tensor& offset,
                   const Tensor& estimated_mean,
-                  const Tensor& estimated_variance, T epsilon, Tensor* y,
+                  const Tensor& estimated_variance, U epsilon, Tensor* y,
                   Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
                   Tensor* saved_inv_var, TensorFormat tensor_format,
                   bool is_training) {
@@ -284,44 +289,44 @@ struct FusedBatchNorm<GPUDevice, T> {
         .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
 
     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
-    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
-    auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<T>(offset);
+    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
+    auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
     auto estimated_mean_ptr =
-        StreamExecutorUtil::AsDeviceMemory<T>(estimated_mean);
+        StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
     auto estimated_variance_ptr =
-        StreamExecutorUtil::AsDeviceMemory<T>(estimated_variance);
-    auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*batch_mean);
+        StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
+    auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
 
-    auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*batch_var);
-    auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*saved_mean);
+    auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
+    auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
     auto saved_inv_var_ptr =
-        StreamExecutorUtil::AsDeviceMemory<T>(*saved_inv_var);
+        StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
 
     GPUDevice d = context->eigen_device<GPUDevice>();
     using perftools::gputools::DeviceMemory;
     Tensor inv_var;
    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
+        context, context->allocate_temp(DataTypeToEnum<U>::value,
                                         estimated_variance.shape(), &inv_var));
-    auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_var);
-    std::function<const DeviceMemory<T>&()> var_to_inv_var =
+    auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
+    std::function<const DeviceMemory<U>&()> var_to_inv_var =
         [d, epsilon, estimated_variance,
-         &inv_var_ptr]() -> const DeviceMemory<T>& {
+         &inv_var_ptr]() -> const DeviceMemory<U>& {
       auto estimated_variance_ptr =
-          StreamExecutorUtil::AsDeviceMemory<T>(estimated_variance);
-      const T* variance =
-          static_cast<const T*>(estimated_variance_ptr.opaque());
-      T* inv_variance = static_cast<T*>(inv_var_ptr.opaque());
+          StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
+      const U* variance =
+          static_cast<const U*>(estimated_variance_ptr.opaque());
+      U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
       int channels = inv_var_ptr.ElementCount();
-      VarianceToInvVariance<T>()(d, variance, epsilon, channels, inv_variance);
+      VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
       return inv_var_ptr;
    };
     const int64 sample_size = batch_size * height * width;
     std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
                                             sample_size]() {
-      T* variance = static_cast<T*>(batch_var_ptr.opaque());
+      U* variance = static_cast<U*>(batch_var_ptr.opaque());
      int channels = batch_var_ptr.ElementCount();
-      InvVarianceToVariance<T>()(d, epsilon, sample_size, channels, variance);
+      InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
     };
 
     bool cudnn_launch_status =
@@ -349,11 +354,11 @@ struct FusedBatchNorm<GPUDevice, T> {
   }
 };
 
-template <typename T>
-struct FusedBatchNormGrad<GPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNormGrad<GPUDevice, T, U> {
   void operator()(OpKernelContext* context, const Tensor& y_backprop,
                   const Tensor& x, const Tensor& scale, const Tensor& mean,
-                  const Tensor& inv_variance, T epsilon, Tensor* x_backprop,
+                  const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
                   Tensor* scale_backprop, Tensor* offset_backprop,
                   TensorFormat tensor_format) {
     auto* stream = context->op_device_context()->stream();
@@ -440,13 +445,13 @@ struct FusedBatchNormGrad<GPUDevice, T> {
     auto y_backprop_ptr =
         StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
-    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
-    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(mean);
-    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_variance);
+    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
+    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
+    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
     auto scale_backprop_ptr =
-        StreamExecutorUtil::AsDeviceMemory<T>(*scale_backprop);
+        StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
     auto offset_backprop_ptr =
-        StreamExecutorUtil::AsDeviceMemory<T>(*offset_backprop);
+        StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
 
     // the cudnn kernel outputs inverse variance in forward and reuse it in
     // backward
@@ -473,28 +478,29 @@ struct FusedBatchNormGrad<GPUDevice, T> {
 };
 
 // Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPEC(T)                                                \
+#define DECLARE_GPU_SPEC(T, U)                                             \
   template <>                                                              \
-  void FusedBatchNormFreezeGrad<GPUDevice, T>::operator()(                 \
+  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(              \
       const GPUDevice& d, const Tensor& y_backprop_input,                  \
      const Tensor& x_input, const Tensor& scale_input,                    \
-      const Tensor& mean_input, const Tensor& variance_input, T epsilon,   \
+      const Tensor& mean_input, const Tensor& variance_input, U epsilon,   \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,            \
-      Tensor* offset_backprop_output, typename TTypes<T>::Vec scratch1,    \
-      typename TTypes<T>::Vec scratch2);                                   \
-  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T>;
-DECLARE_GPU_SPEC(float);
+      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,    \
+      typename TTypes<U>::Vec scratch2);                                   \
+  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
+DECLARE_GPU_SPEC(float, float);
+DECLARE_GPU_SPEC(Eigen::half, float);
 #endif  // GOOGLE_CUDA
 
 }  // namespace functor
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
 class FusedBatchNormOp : public OpKernel {
  public:
   explicit FusedBatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
     float epsilon;
     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
-    epsilon_ = T(epsilon);
+    epsilon_ = U(epsilon);
     string tensor_format;
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -552,26 +558,26 @@ class FusedBatchNormOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                      &saved_maybe_inv_var));
 
-    functor::FusedBatchNorm<Device, T>()(
+    functor::FusedBatchNorm<Device, T, U>()(
         context, x, scale, offset, estimated_mean, estimated_variance,
         epsilon_, y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
         tensor_format_, is_training_);
   }
 
  private:
-  T epsilon_;
+  U epsilon_;
   TensorFormat tensor_format_;
   bool is_training_;
 };
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
 class FusedBatchNormGradOp : public OpKernel {
  public:
   explicit FusedBatchNormGradOp(OpKernelConstruction* context)
       : OpKernel(context) {
     float epsilon;
     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
-    epsilon_ = T(epsilon);
+    epsilon_ = U(epsilon);
     string tensor_format;
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -631,7 +637,7 @@ class FusedBatchNormGradOp : public OpKernel {
         context, context->allocate_output(4, TensorShape({}), &placeholder_2));
 
     if (is_training_) {
-      functor::FusedBatchNormGrad<Device, T>()(
+      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, tensor_format_);
@@ -644,36 +650,79 @@ class FusedBatchNormGradOp : public OpKernel {
          << "NHWC tensor format for now.";
       Tensor scratch1, scratch2;
       OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<T>::value,
+                     context->allocate_temp(DataTypeToEnum<U>::value,
                                             scale_offset_shape, &scratch1));
       OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<T>::value,
+                     context->allocate_temp(DataTypeToEnum<U>::value,
                                             scale_offset_shape, &scratch2));
-      functor::FusedBatchNormFreezeGrad<Device, T>()(
+      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
-          x_backprop, scale_backprop, offset_backprop, scratch1.vec<T>(),
-          scratch2.vec<T>());
+          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
+          scratch2.vec<U>());
     }
   }
 
 private:
-  T epsilon_;
+  U epsilon_;
   TensorFormat tensor_format_;
   bool is_training_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_CPU),
-                        FusedBatchNormOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+    FusedBatchNormOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+    FusedBatchNormGradOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormGradOp<CPUDevice, float, float>);
 
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGrad").Device(DEVICE_CPU),
-                        FusedBatchNormGradOp<CPUDevice, float>);
 
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_GPU),
-                        FusedBatchNormOp<GPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGrad").Device(DEVICE_GPU),
-                        FusedBatchNormGradOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+    FusedBatchNormOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+    FusedBatchNormGradOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<float>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormGradOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<Eigen::half>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<Eigen::half>("T")
+                            .TypeConstraint<float>("U"),
+                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);
+
 #endif
 
 }  // namespace tensorflow
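The registrations above also add a half/float FusedBatchNormGradV2 kernel on the GPU. A sketch of the backward path it serves follows; it is not part of the commit and makes the same assumptions as the earlier sketch (CUDA-enabled TF 1.x build, Python wrapper dispatching float16 inputs to the V2 ops).

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(8, 4, 4, 16), dtype=tf.float16)
scale = tf.ones([16], dtype=tf.float32)
offset = tf.zeros([16], dtype=tf.float32)

y, _, _ = tf.nn.fused_batch_norm(x, scale, offset, is_training=True)
# Gradients w.r.t. the float16 input come back as float16; gradients w.r.t.
# the float32 scale/offset stay float32, matching the T/U kernel split.
dx, dscale, doffset = tf.gradients(y, [x, scale, offset])

with tf.Session() as sess:
    gx, gs, go = sess.run([dx, dscale, doffset])
    print(gx.dtype, gs.dtype, go.dtype)  # expect: float16 float32 float32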