diff options
Diffstat (limited to 'tensorflow/python/keras/layers/normalization.py')
-rw-r--r-- | tensorflow/python/keras/layers/normalization.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 58c8a8a66d..a7835bc0a2 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -370,7 +370,7 @@ class BatchNormalization(Layer): decay = ops.convert_to_tensor(1.0 - momentum, name='decay') if decay.dtype != variable.dtype.base_dtype: decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - value) * decay + update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): @@ -619,6 +619,10 @@ class BatchNormalization(Layer): else: mean, variance = self.moving_mean, self.moving_variance + mean = math_ops.cast(mean, inputs.dtype) + variance = math_ops.cast(variance, inputs.dtype) + if offset is not None: + offset = math_ops.cast(offset, inputs.dtype) outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), |