aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-03-19 17:04:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 17:09:02 -0700
commit4f5b7b42e2f8cb6b6e6730b6ada0edbee67dbfe3 (patch)
treeeea6977f37f1727fa07e022fae4155441c52e441 /tensorflow/python/layers
parentb1208ba0197547e75c3860b385d036e3909f8ea9 (diff)
Fix test failure
PiperOrigin-RevId: 189666053
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/normalization.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 8b79a92cc4..11daf01670 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -364,8 +364,9 @@ class BatchNormalization(base.Layer):
[variable, value, momentum]) as scope:
with ops.colocate_with(variable):
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
- update_delta = math_ops.multiply(
- math_ops.subtract(variable.read_value(), value), decay)
+ if decay.dtype != variable.dtype.base_dtype:
+ decay = math_ops.cast(decay, variable.dtype.base_dtype)
+ update_delta = (variable - value) * decay
return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):