path: root/tensorflow/python/layers/normalization.py
Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--  tensorflow/python/layers/normalization.py | 55
1 file changed, 36 insertions(+), 19 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 9d9b2b3941..83237b8733 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
@@ -239,6 +240,12 @@ class BatchNormalization(base.Layer):
raise ValueError('Unsupported axis, fused batch norm only supports '
'axis == [1] or axis == [3]')
+ # Raise parameters of fp16 batch norm to fp32
+ if self.dtype == dtypes.float16:
+ param_dtype = dtypes.float32
+ else:
+ param_dtype = self.dtype or dtypes.float32
+
axis_to_dim = {x: input_shape[x].value for x in self.axis}
for x in axis_to_dim:
if axis_to_dim[x] is None:
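
Not part of the diff: read on its own, the added dtype-selection rule resolves as sketched below. `_param_dtype` is a hypothetical stand-in for the inline logic above, and `layer_dtype` stands in for `self.dtype`:

from tensorflow.python.framework import dtypes

def _param_dtype(layer_dtype):
  # Mirror of the rule above: fp16 layers keep their batch-norm parameters
  # in fp32, and an unset layer dtype also falls back to fp32.
  if layer_dtype == dtypes.float16:
    return dtypes.float32
  return layer_dtype or dtypes.float32

assert _param_dtype(dtypes.float16) == dtypes.float32
assert _param_dtype(None) == dtypes.float32
assert _param_dtype(dtypes.float64) == dtypes.float64
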
@@ -260,28 +267,34 @@ class BatchNormalization(base.Layer):
self.axis[idx] = x + 1 # Account for added dimension
if self.scale:
- self.gamma = self.add_variable(name='gamma',
- shape=param_shape,
- initializer=self.gamma_initializer,
- regularizer=self.gamma_regularizer,
- constraint=self.gamma_constraint,
- trainable=True)
+ self.gamma = self.add_variable(
+ name='gamma',
+ shape=param_shape,
+ dtype=param_dtype,
+ initializer=self.gamma_initializer,
+ regularizer=self.gamma_regularizer,
+ constraint=self.gamma_constraint,
+ trainable=True)
else:
self.gamma = None
if self.fused:
- self._gamma_const = array_ops.constant(1.0, shape=param_shape)
+ self._gamma_const = array_ops.constant(
+ 1.0, dtype=param_dtype, shape=param_shape)
if self.center:
- self.beta = self.add_variable(name='beta',
- shape=param_shape,
- initializer=self.beta_initializer,
- regularizer=self.beta_regularizer,
- constraint=self.beta_constraint,
- trainable=True)
+ self.beta = self.add_variable(
+ name='beta',
+ shape=param_shape,
+ dtype=param_dtype,
+ initializer=self.beta_initializer,
+ regularizer=self.beta_regularizer,
+ constraint=self.beta_constraint,
+ trainable=True)
else:
self.beta = None
if self.fused:
- self._beta_const = array_ops.constant(0.0, shape=param_shape)
+ self._beta_const = array_ops.constant(
+ 0.0, dtype=param_dtype, shape=param_shape)
# Disable variable partitioning when creating the moving mean and variance
try:
@@ -293,12 +306,14 @@ class BatchNormalization(base.Layer):
self.moving_mean = self.add_variable(
name='moving_mean',
shape=param_shape,
+ dtype=param_dtype,
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self.add_variable(
name='moving_variance',
shape=param_shape,
+ dtype=param_dtype,
initializer=self.moving_variance_initializer,
trainable=False)
@@ -312,10 +327,12 @@ class BatchNormalization(base.Layer):
# stack to be cleared. The nested ones use a `lambda` to set the desired
# device and ignore any devices that may be set by the custom getter.
def _renorm_variable(name, shape):
- var = self.add_variable(name=name,
- shape=shape,
- initializer=init_ops.zeros_initializer(),
- trainable=False)
+ var = self.add_variable(
+ name=name,
+ shape=shape,
+ dtype=param_dtype,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False)
return var
with ops.device(None):
@@ -356,7 +373,6 @@ class BatchNormalization(base.Layer):
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
- # TODO(reedwm): Add support for fp16 inputs.
beta = self.beta if self.center else self._beta_const
gamma = self.gamma if self.scale else self._gamma_const
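
Not part of the diff: dropping the TODO here relies on op-level float16 support in tf.nn.fused_batch_norm, which is expected to take float16 inputs while keeping scale/offset in float32, matching the fp32 parameters created in build(). A minimal TF 1.x-style sketch under that assumption:

import tensorflow as tf

x = tf.placeholder(tf.float16, shape=[None, 8, 8, 4])   # NHWC, fp16 activations
scale = tf.ones([4], dtype=tf.float32)                    # gamma stays fp32
offset = tf.zeros([4], dtype=tf.float32)                  # beta stays fp32
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, epsilon=1e-3, data_format='NHWC', is_training=True)
# y is float16 like the input; batch_mean / batch_var come back as float32.
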
@@ -752,6 +768,7 @@ def batch_normalization(inputs,
virtual_batch_size=virtual_batch_size,
adjustment=adjustment,
name=name,
+ dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs, training=training)
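
Not part of the diff: a minimal usage sketch (assuming the TF 1.x tf.layers API) of the intended effect. With float16 inputs, the layer now creates its parameters in float32 while the output keeps the input dtype:

import tensorflow as tf

x = tf.placeholder(tf.float16, shape=[None, 8, 8, 16])
y = tf.layers.batch_normalization(x, axis=3, fused=True, training=True)

for v in tf.global_variables():
  if 'batch_normalization' in v.name:
    # gamma, beta, moving_mean and moving_variance are created as float32.
    print(v.name, v.dtype.base_dtype)
print(y.dtype)  # float16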