Diffstat (limited to 'tensorflow/python/layers/normalization.py')
-rw-r--r--  tensorflow/python/layers/normalization.py  55
1 file changed, 36 insertions, 19 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 9d9b2b3941..83237b8733 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -26,6 +26,7 @@ import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import base
@@ -239,6 +240,12 @@ class BatchNormalization(base.Layer):
         raise ValueError('Unsupported axis, fused batch norm only supports '
                          'axis == [1] or axis == [3]')
 
+    # Raise parameters of fp16 batch norm to fp32
+    if self.dtype == dtypes.float16:
+      param_dtype = dtypes.float32
+    else:
+      param_dtype = self.dtype or dtypes.float32
+
     axis_to_dim = {x: input_shape[x].value for x in self.axis}
     for x in axis_to_dim:
       if axis_to_dim[x] is None:
@@ -260,28 +267,34 @@ class BatchNormalization(base.Layer):
           self.axis[idx] = x + 1  # Account for added dimension
 
     if self.scale:
-      self.gamma = self.add_variable(name='gamma',
-                                     shape=param_shape,
-                                     initializer=self.gamma_initializer,
-                                     regularizer=self.gamma_regularizer,
-                                     constraint=self.gamma_constraint,
-                                     trainable=True)
+      self.gamma = self.add_variable(
+          name='gamma',
+          shape=param_shape,
+          dtype=param_dtype,
+          initializer=self.gamma_initializer,
+          regularizer=self.gamma_regularizer,
+          constraint=self.gamma_constraint,
+          trainable=True)
     else:
       self.gamma = None
       if self.fused:
-        self._gamma_const = array_ops.constant(1.0, shape=param_shape)
+        self._gamma_const = array_ops.constant(
+            1.0, dtype=param_dtype, shape=param_shape)
 
     if self.center:
-      self.beta = self.add_variable(name='beta',
-                                    shape=param_shape,
-                                    initializer=self.beta_initializer,
-                                    regularizer=self.beta_regularizer,
-                                    constraint=self.beta_constraint,
-                                    trainable=True)
+      self.beta = self.add_variable(
+          name='beta',
+          shape=param_shape,
+          dtype=param_dtype,
+          initializer=self.beta_initializer,
+          regularizer=self.beta_regularizer,
+          constraint=self.beta_constraint,
+          trainable=True)
     else:
       self.beta = None
       if self.fused:
-        self._beta_const = array_ops.constant(0.0, shape=param_shape)
+        self._beta_const = array_ops.constant(
+            0.0, dtype=param_dtype, shape=param_shape)
 
     # Disable variable partitioning when creating the moving mean and variance
     try:
@@ -293,12 +306,14 @@ class BatchNormalization(base.Layer):
       self.moving_mean = self.add_variable(
           name='moving_mean',
           shape=param_shape,
+          dtype=param_dtype,
           initializer=self.moving_mean_initializer,
           trainable=False)
 
       self.moving_variance = self.add_variable(
           name='moving_variance',
           shape=param_shape,
+          dtype=param_dtype,
           initializer=self.moving_variance_initializer,
           trainable=False)
 
@@ -312,10 +327,12 @@ class BatchNormalization(base.Layer):
       # stack to be cleared. The nested ones use a `lambda` to set the desired
      # device and ignore any devices that may be set by the custom getter.
       def _renorm_variable(name, shape):
-        var = self.add_variable(name=name,
-                                shape=shape,
-                                initializer=init_ops.zeros_initializer(),
-                                trainable=False)
+        var = self.add_variable(
+            name=name,
+            shape=shape,
+            dtype=param_dtype,
+            initializer=init_ops.zeros_initializer(),
+            trainable=False)
         return var
 
       with ops.device(None):
@@ -356,7 +373,6 @@ class BatchNormalization(base.Layer):
 
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
-    # TODO(reedwm): Add support for fp16 inputs.
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const
@@ -752,6 +768,7 @@ def batch_normalization(inputs,
       virtual_batch_size=virtual_batch_size,
      adjustment=adjustment,
       name=name,
+      dtype=inputs.dtype.base_dtype,
       _reuse=reuse,
       _scope=name)
   return layer.apply(inputs, training=training)
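With the functional wrapper now forwarding inputs.dtype.base_dtype to the layer and build() promoting float16 parameters to float32, a half-precision input yields full-precision gamma/beta and moving statistics. Below is a minimal sketch of that behavior against the TF 1.x tf.layers API this file implements; the placeholder shape and the scope name 'bn' are illustrative assumptions, not part of the commit.

import tensorflow as tf

# Hypothetical NHWC fp16 activations; shape and variable scope name are illustrative only.
inputs = tf.placeholder(tf.float16, shape=[None, 32, 32, 64])
outputs = tf.layers.batch_normalization(inputs, axis=3, training=True, name='bn')

print(outputs.dtype)  # float16: activations keep the input precision
for var in tf.global_variables():
  if var.name.startswith('bn/'):
    # gamma, beta, moving_mean and moving_variance are now created as float32
    print(var.name, var.dtype.base_dtype)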