diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-03-19 17:04:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-19 17:09:02 -0700 |
commit | 4f5b7b42e2f8cb6b6e6730b6ada0edbee67dbfe3 (patch) | |
tree | eea6977f37f1727fa07e022fae4155441c52e441 /tensorflow/python/layers | |
parent | b1208ba0197547e75c3860b385d036e3909f8ea9 (diff) |
Fix test failure
PiperOrigin-RevId: 189666053
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/normalization.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 8b79a92cc4..11daf01670 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -364,8 +364,9 @@ class BatchNormalization(base.Layer): [variable, value, momentum]) as scope: with ops.colocate_with(variable): decay = ops.convert_to_tensor(1.0 - momentum, name='decay') - update_delta = math_ops.multiply( - math_ops.subtract(variable.read_value(), value), decay) + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - value) * decay return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): |