aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/normalization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/normalization.py')
-rw-r--r--tensorflow/python/keras/layers/normalization.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 58c8a8a66d..a7835bc0a2 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -370,7 +370,7 @@ class BatchNormalization(Layer):
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
+ update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
@@ -619,6 +619,10 @@ class BatchNormalization(Layer):
else:
mean, variance = self.moving_mean, self.moving_variance
+ mean = math_ops.cast(mean, inputs.dtype)
+ variance = math_ops.cast(variance, inputs.dtype)
+ if offset is not None:
+ offset = math_ops.cast(offset, inputs.dtype)
outputs = nn.batch_normalization(inputs,
_broadcast(mean),
_broadcast(variance),