author    Reed Wanderman-Milne <reedwm@google.com>  2017-09-27 12:58:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-27 13:04:57 -0700
commit    759690f026a1a08b3ac5cc84d8498c05c32b2a7d (patch)
tree      9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1 /tensorflow/core/kernels/fused_batch_norm_op.cc
parent    20370104cd8adf4c3f9068dfe95bde54cccadfa5 (diff)
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16.

PiperOrigin-RevId: 170239448
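The change parameterizes every functor on two types: T for the x/y data (now optionally Eigen::half on the GPU) and a new U for scale, offset, and the mean/variance statistics (kept at float). A minimal standalone sketch of that mixed-precision pattern, illustrative rather than the kernel's actual code:

    #include <cmath>
    #include <cstddef>

    // Mirrors the FusedBatchNorm<Device, T, U> split below: data type T (e.g. a
    // 16-bit float) for x and y, higher-precision type U (float) for scale,
    // offset, mean, and variance. All arithmetic happens in U.
    template <typename T, typename U>
    void BatchNormInference(const T* x, const U* scale, const U* offset,
                            const U* mean, const U* variance, U epsilon,
                            std::size_t n, std::size_t channels, T* y) {
      for (std::size_t i = 0; i < n; ++i) {
        const std::size_t c = i % channels;  // NHWC: channel is the innermost axis
        const U normalized =
            (static_cast<U>(x[i]) - mean[c]) / std::sqrt(variance[c] + epsilon);
        y[i] = static_cast<T>(normalized * scale[c] + offset[c]);
      }
    }

Casting each element up to U before normalizing is what lets a float16 input keep float32 statistics, which is the constraint the commit message states.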
Diffstat (limited to 'tensorflow/core/kernels/fused_batch_norm_op.cc')
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc | 179
1 file changed, 114 insertions(+), 65 deletions(-)
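In the GPU forward pass below, cudnn takes an inverse variance on input and produces one on output, so the functor wraps the call with the var_to_inv_var and inv_var_to_var lambdas (backed by the VarianceToInvVariance and InvVarianceToVariance CUDA functors). A plain-C++ sketch of the math those conversions are assumed to perform; the Bessel-correction detail is an assumption about InvVarianceToVariance, not spelled out in this diff:

    #include <cmath>

    // Assumed forward transform: inv_var = 1 / sqrt(var + epsilon).
    float VarianceToInvVarianceSketch(float variance, float epsilon) {
      return 1.0f / std::sqrt(variance + epsilon);
    }

    // Assumed reverse transform: undo the mapping above, then rescale by
    // sample_size / (sample_size - 1), on the assumption that cudnn saves biased
    // (1/N) statistics while the op reports the unbiased (1/(N-1)) batch variance.
    float InvVarianceToVarianceSketch(float inv_variance, float epsilon,
                                      long long sample_size) {
      float variance = 1.0f / (inv_variance * inv_variance) - epsilon;
      if (sample_size > 1) {
        variance *= static_cast<float>(sample_size) /
                    static_cast<float>(sample_size - 1);
      }
      return variance > 0.0f ? variance : 0.0f;  // clamp numerical noise
    }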
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 92b093eec6..0ecb829f34 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -37,23 +37,28 @@ using GPUDevice = Eigen::GpuDevice;
namespace functor {
// Functor used by FusedBatchNormOp to do the computations.
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;
-template <typename T>
-struct FusedBatchNorm<CPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNorm<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x_input,
const Tensor& scale_input, const Tensor& offset_input,
const Tensor& estimated_mean_input,
- const Tensor& estimated_variance_input, T epsilon,
+ const Tensor& estimated_variance_input, U epsilon,
Tensor* y_output, Tensor* batch_mean_output,
Tensor* batch_var_output, Tensor* saved_mean_output,
Tensor* saved_var_output, TensorFormat tensor_format,
bool is_training) {
+ // Currently U is ignored, since we only support the case where T and U are
+ // both float32.
+ // TODO(reedwm): Add float16 support, use U, and remove these asserts.
+ static_assert(std::is_same<T, float>::value, "T currently must be float.");
+ static_assert(std::is_same<U, float>::value, "U currently must be float.");
OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
errors::Internal("The CPU implementation of FusedBatchNorm "
"only supports NHWC tensor format for now."));
@@ -128,8 +133,8 @@ struct FusedBatchNorm<CPUDevice, T> {
}
};
-template <typename T>
-struct FusedBatchNormGrad<CPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNormGrad<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
const Tensor& x_input, const Tensor& scale_input,
const Tensor& mean_input, const Tensor& variance_input,
@@ -214,12 +219,12 @@ struct FusedBatchNormGrad<CPUDevice, T> {
};
#if GOOGLE_CUDA
-template <typename T>
-struct FusedBatchNorm<GPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNorm<GPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x,
const Tensor& scale, const Tensor& offset,
const Tensor& estimated_mean,
- const Tensor& estimated_variance, T epsilon, Tensor* y,
+ const Tensor& estimated_variance, U epsilon, Tensor* y,
Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
Tensor* saved_inv_var, TensorFormat tensor_format,
bool is_training) {
@@ -284,44 +289,44 @@ struct FusedBatchNorm<GPUDevice, T> {
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
- auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
- auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<T>(offset);
+ auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
+ auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
auto estimated_mean_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(estimated_mean);
+ StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
auto estimated_variance_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(estimated_variance);
- auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*batch_mean);
+ StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
+ auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
- auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*batch_var);
- auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*saved_mean);
+ auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
+ auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
auto saved_inv_var_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(*saved_inv_var);
+ StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
GPUDevice d = context->eigen_device<GPUDevice>();
using perftools::gputools::DeviceMemory;
Tensor inv_var;
OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::value,
+ context, context->allocate_temp(DataTypeToEnum<U>::value,
estimated_variance.shape(), &inv_var));
- auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_var);
- std::function<const DeviceMemory<T>&()> var_to_inv_var =
+ auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
+ std::function<const DeviceMemory<U>&()> var_to_inv_var =
[d, epsilon, estimated_variance,
- &inv_var_ptr]() -> const DeviceMemory<T>& {
+ &inv_var_ptr]() -> const DeviceMemory<U>& {
auto estimated_variance_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(estimated_variance);
- const T* variance =
- static_cast<const T*>(estimated_variance_ptr.opaque());
- T* inv_variance = static_cast<T*>(inv_var_ptr.opaque());
+ StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
+ const U* variance =
+ static_cast<const U*>(estimated_variance_ptr.opaque());
+ U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
int channels = inv_var_ptr.ElementCount();
- VarianceToInvVariance<T>()(d, variance, epsilon, channels, inv_variance);
+ VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
return inv_var_ptr;
};
const int64 sample_size = batch_size * height * width;
std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
sample_size]() {
- T* variance = static_cast<T*>(batch_var_ptr.opaque());
+ U* variance = static_cast<U*>(batch_var_ptr.opaque());
int channels = batch_var_ptr.ElementCount();
- InvVarianceToVariance<T>()(d, epsilon, sample_size, channels, variance);
+ InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
};
bool cudnn_launch_status =
@@ -349,11 +354,11 @@ struct FusedBatchNorm<GPUDevice, T> {
}
};
-template <typename T>
-struct FusedBatchNormGrad<GPUDevice, T> {
+template <typename T, typename U>
+struct FusedBatchNormGrad<GPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& y_backprop,
const Tensor& x, const Tensor& scale, const Tensor& mean,
- const Tensor& inv_variance, T epsilon, Tensor* x_backprop,
+ const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
Tensor* scale_backprop, Tensor* offset_backprop,
TensorFormat tensor_format) {
auto* stream = context->op_device_context()->stream();
@@ -440,13 +445,13 @@ struct FusedBatchNormGrad<GPUDevice, T> {
auto y_backprop_ptr =
StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
- auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<T>(scale);
- auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<T>(mean);
- auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<T>(inv_variance);
+ auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
+ auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
+ auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
auto scale_backprop_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(*scale_backprop);
+ StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
auto offset_backprop_ptr =
- StreamExecutorUtil::AsDeviceMemory<T>(*offset_backprop);
+ StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
// the cudnn kernel outputs inverse variance in forward and reuse it in
// backward
@@ -473,28 +478,29 @@ struct FusedBatchNormGrad<GPUDevice, T> {
};
// Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPEC(T) \
+#define DECLARE_GPU_SPEC(T, U) \
template <> \
- void FusedBatchNormFreezeGrad<GPUDevice, T>::operator()( \
+ void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()( \
const GPUDevice& d, const Tensor& y_backprop_input, \
const Tensor& x_input, const Tensor& scale_input, \
- const Tensor& mean_input, const Tensor& variance_input, T epsilon, \
+ const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
Tensor* x_backprop_output, Tensor* scale_backprop_output, \
- Tensor* offset_backprop_output, typename TTypes<T>::Vec scratch1, \
- typename TTypes<T>::Vec scratch2); \
- extern template struct FusedBatchNormFreezeGrad<GPUDevice, T>;
-DECLARE_GPU_SPEC(float);
+ Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1, \
+ typename TTypes<U>::Vec scratch2); \
+ extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
+DECLARE_GPU_SPEC(float, float);
+DECLARE_GPU_SPEC(Eigen::half, float);
#endif // GOOGLE_CUDA
} // namespace functor
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
class FusedBatchNormOp : public OpKernel {
public:
explicit FusedBatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
float epsilon;
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
- epsilon_ = T(epsilon);
+ epsilon_ = U(epsilon);
string tensor_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -552,26 +558,26 @@ class FusedBatchNormOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
&saved_maybe_inv_var));
- functor::FusedBatchNorm<Device, T>()(
+ functor::FusedBatchNorm<Device, T, U>()(
context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
tensor_format_, is_training_);
}
private:
- T epsilon_;
+ U epsilon_;
TensorFormat tensor_format_;
bool is_training_;
};
-template <typename Device, typename T>
+template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
public:
explicit FusedBatchNormGradOp(OpKernelConstruction* context)
: OpKernel(context) {
float epsilon;
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
- epsilon_ = T(epsilon);
+ epsilon_ = U(epsilon);
string tensor_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -631,7 +637,7 @@ class FusedBatchNormGradOp : public OpKernel {
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
if (is_training_) {
- functor::FusedBatchNormGrad<Device, T>()(
+ functor::FusedBatchNormGrad<Device, T, U>()(
context, y_backprop, x, scale, saved_mean_or_pop_mean,
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
offset_backprop, tensor_format_);
@@ -644,36 +650,79 @@ class FusedBatchNormGradOp : public OpKernel {
<< "NHWC tensor format for now.";
Tensor scratch1, scratch2;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
+ context->allocate_temp(DataTypeToEnum<U>::value,
scale_offset_shape, &scratch1));
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
+ context->allocate_temp(DataTypeToEnum<U>::value,
scale_offset_shape, &scratch2));
- functor::FusedBatchNormFreezeGrad<Device, T>()(
+ functor::FusedBatchNormFreezeGrad<Device, T, U>()(
context->eigen_device<Device>(), y_backprop, x, scale,
saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
- x_backprop, scale_backprop, offset_backprop, scratch1.vec<T>(),
- scratch2.vec<T>());
+ x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
+ scratch2.vec<U>());
}
}
private:
- T epsilon_;
+ U epsilon_;
TensorFormat tensor_format_;
bool is_training_;
};
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_CPU),
- FusedBatchNormOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ FusedBatchNormOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ FusedBatchNormGradOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormOp<CPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormGradOp<CPUDevice, float, float>);
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGrad").Device(DEVICE_CPU),
- FusedBatchNormGradOp<CPUDevice, float>);
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNorm").Device(DEVICE_GPU),
- FusedBatchNormOp<GPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGrad").Device(DEVICE_GPU),
- FusedBatchNormGradOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ FusedBatchNormOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ FusedBatchNormGradOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormGradOp<GPUDevice, float, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<Eigen::half>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormOp<GPUDevice, Eigen::half, float>);
+
+REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<Eigen::half>("T")
+ .TypeConstraint<float>("U"),
+ FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);
+
#endif
} // namespace tensorflow
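The registrations above pin each (T, U) pairing explicitly: the V1 ops gain a TypeConstraint on "T", and the V2 ops constrain both "T" and "U", with Eigen::half offered only on the GPU, where the functors are instantiated for it. Extending the scheme to another pairing would follow the same shape; the double registration below is hypothetical and would additionally require instantiating the CPU functors for double (the static_asserts in this commit pin the CPU path to float):

    // Hypothetical, not part of this commit: a further (T, U) pairing would be
    // registered the same way, once the corresponding functor instantiations
    // and op type lists exist.
    REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                                .Device(DEVICE_CPU)
                                .TypeConstraint<double>("T")
                                .TypeConstraint<double>("U"),
                            FusedBatchNormOp<CPUDevice, double, double>);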