Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
 tensorflow/contrib/layers/python/layers/layers.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 46b3eeae91..f1debc8590 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -286,7 +286,6 @@ def _fused_batch_norm(inputs,
     ValueError: If the rank of `inputs` is neither 2 or 4.
     ValueError: If rank or `C` dimension of `inputs` is undefined.
   """
-  # TODO(reedwm): Add support for fp16 inputs.
   if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
     raise ValueError('data_format has to be either NCHW or NHWC.')
   with variable_scope.variable_scope(
@@ -310,7 +309,6 @@ def _fused_batch_norm(inputs,
         new_shape = [-1, channels, 1, 1]
       inputs = array_ops.reshape(inputs, new_shape)
     inputs_shape = inputs.get_shape()
-    dtype = inputs.dtype.base_dtype
     if data_format == DATA_FORMAT_NHWC:
       params_shape = inputs_shape[-1:]
     else:
@@ -320,9 +318,10 @@ def _fused_batch_norm(inputs,
                        (inputs.name, params_shape))
 
     # Allocate parameters for the beta and gamma of the normalization.
-    trainable_beta = trainable and center
     beta_collections = utils.get_variable_collections(variables_collections,
                                                       'beta')
+    # Float32 required to avoid precision-loss when using fp16 input/output
+    variable_dtype = dtypes.float32
     if not param_initializers:
       param_initializers = {}
     if not param_regularizers:
@@ -336,13 +335,13 @@ def _fused_batch_norm(inputs,
beta = variables.model_variable(
'beta',
shape=params_shape,
- dtype=dtype,
+ dtype=variable_dtype,
initializer=beta_initializer,
regularizer=beta_regularizer,
collections=beta_collections,
- trainable=trainable_beta)
+ trainable=trainable)
else:
- beta = array_ops.constant(0.0, shape=params_shape)
+ beta = array_ops.constant(0.0, dtype=variable_dtype, shape=params_shape)
if scale:
gamma_collections = utils.get_variable_collections(
@@ -352,13 +351,13 @@ def _fused_batch_norm(inputs,
       gamma = variables.model_variable(
           'gamma',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=gamma_initializer,
           regularizer=gamma_regularizer,
           collections=gamma_collections,
           trainable=trainable)
     else:
-      gamma = array_ops.constant(1.0, shape=params_shape)
+      gamma = array_ops.constant(1.0, dtype=variable_dtype, shape=params_shape)
 
     # Create moving_mean and moving_variance variables and add them to the
     # appropriate collections. We disable variable partitioning while creating
@@ -375,7 +374,7 @@ def _fused_batch_norm(inputs,
       moving_mean = variables.model_variable(
           'moving_mean',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=moving_mean_initializer,
           trainable=False,
           collections=moving_mean_collections)
@@ -386,7 +385,7 @@ def _fused_batch_norm(inputs,
       moving_variance = variables.model_variable(
           'moving_variance',
           shape=params_shape,
-          dtype=dtype,
+          dtype=variable_dtype,
           initializer=moving_variance_initializer,
           trainable=False,
           collections=moving_variance_collections)
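
In short, the patch pins the batch-norm parameters to float32 instead of the input dtype: beta, gamma, moving_mean and moving_variance are now created with variable_dtype = dtypes.float32, the constant fallbacks get an explicit float32 dtype, and `trainable` is passed directly since the old `trainable_beta = trainable and center` guard only mattered on the branch where `center` is already True (when `center` is False, beta is a constant, not a variable). This removes the TODO that blocked fp16 inputs. Below is a minimal sketch of the resulting behaviour, assuming TensorFlow 1.x with `tf.contrib.layers.batch_norm` (which routes to `_fused_batch_norm` when `fused=True`); shapes and variable names are illustrative only:

# Sketch (assumes TF 1.x / tf.contrib): fused batch norm on a float16 input.
import tensorflow as tf

x = tf.placeholder(tf.float16, shape=[None, 32, 32, 64])   # NHWC layout
y = tf.contrib.layers.batch_norm(x, fused=True, is_training=True, scale=True)

print(y.dtype)  # float16 -- the output keeps the input dtype
for v in tf.global_variables():
  # beta, gamma, moving_mean and moving_variance should all report float32,
  # matching the variable_dtype introduced above.
  print(v.name, v.dtype.base_dtype)

Keeping the parameters and moving statistics in fp32 while computing in fp16 follows the usual mixed-precision convention: the many small updates to the moving averages would otherwise lose precision in half floats.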