Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 19
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 46b3eeae91..f1debc8590 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -286,7 +286,6 @@ def _fused_batch_norm(inputs,
     ValueError: If the rank of `inputs` is neither 2 or 4.
     ValueError: If rank or `C` dimension of `inputs` is undefined.
   """
-  # TODO(reedwm): Add support for fp16 inputs.
   if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
     raise ValueError('data_format has to be either NCHW or NHWC.')
   with variable_scope.variable_scope(
@@ -310,7 +309,6 @@ def _fused_batch_norm(inputs,
       new_shape = [-1, channels, 1, 1]
       inputs = array_ops.reshape(inputs, new_shape)
       inputs_shape = inputs.get_shape()
-    dtype = inputs.dtype.base_dtype
     if data_format == DATA_FORMAT_NHWC:
       params_shape = inputs_shape[-1:]
     else:
@@ -320,9 +318,10 @@ def _fused_batch_norm(inputs,
                        (inputs.name, params_shape))

     # Allocate parameters for the beta and gamma of the normalization.
-    trainable_beta = trainable and center
     beta_collections = utils.get_variable_collections(variables_collections,
                                                       'beta')
+    # Float32 required to avoid precision-loss when using fp16 input/output
+    variable_dtype = dtypes.float32
     if not param_initializers:
       param_initializers = {}
     if not param_regularizers:
@@ -336,13 +335,13 @@ def _fused_batch_norm(inputs,
       beta = variables.model_variable(
           'beta',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=beta_initializer,
           regularizer=beta_regularizer,
           collections=beta_collections,
-          trainable=trainable_beta)
+          trainable=trainable)
     else:
-      beta = array_ops.constant(0.0, shape=params_shape)
+      beta = array_ops.constant(0.0, dtype=variable_dtype, shape=params_shape)

     if scale:
       gamma_collections = utils.get_variable_collections(
@@ -352,13 +351,13 @@ def _fused_batch_norm(inputs,
       gamma = variables.model_variable(
           'gamma',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=gamma_initializer,
           regularizer=gamma_regularizer,
           collections=gamma_collections,
           trainable=trainable)
     else:
-      gamma = array_ops.constant(1.0, shape=params_shape)
+      gamma = array_ops.constant(1.0, dtype=variable_dtype, shape=params_shape)

     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections. We disable variable partitioning while creating
@@ -375,7 +374,7 @@ def _fused_batch_norm(inputs,
       moving_mean = variables.model_variable(
           'moving_mean',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=moving_mean_initializer,
           trainable=False,
           collections=moving_mean_collections)
@@ -386,7 +385,7 @@ def _fused_batch_norm(inputs,
       moving_variance = variables.model_variable(
           'moving_variance',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=moving_variance_initializer,
           trainable=False,
           collections=moving_variance_collections)
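
The substantive change: `_fused_batch_norm` previously created beta, gamma, moving_mean, and moving_variance with the input's dtype; it now always creates them as float32, which is what lets the fp16 TODO be removed. A minimal numpy sketch (not part of the commit) of the precision loss the new comment refers to: a statistic accumulated in fp16 stops updating once the per-step increment falls below half the representable spacing at its current magnitude, while a float32 accumulator does not.

import numpy as np

acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
step = 1e-4  # stand-in for a small per-batch moving-average update

for _ in range(10000):
  acc16 = np.float16(acc16 + np.float16(step))  # fp16 statistic (old behavior)
  acc32 = acc32 + np.float32(step)              # fp32 statistic (patched behavior)

print(acc16)  # 0.25 -- stalls: 1e-4 < half the fp16 spacing (2**-12) at 0.25
print(acc32)  # ~1.0 -- keeps accumulating

With the patch applied, fp16 activations can flow through the fused path while the layer's variables stay float32. A usage sketch, assuming a TF 1.x build that includes this commit (the input shape is illustrative):

import tensorflow as tf  # TF 1.x, where tf.contrib is available

x = tf.placeholder(tf.float16, shape=[None, 32, 32, 64])
y = tf.contrib.layers.batch_norm(x, fused=True, scale=True, is_training=True)

for v in tf.global_variables():
  print(v.name, v.dtype)  # beta, gamma, moving_* are created as float32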