author     Chris Ying <chrisying@google.com>  2018-05-30 17:38:13 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-05-30 17:40:26 -0700
commit  2a484497062677f5cf0205ee3b9c28a64f03fe04
tree    e9c4089ffdac64eb993106b7bb46988d495913ad
parent  49535c9da686ea24f4e755e90fdaaa97f9f91b9d
Fix bug with renorm + virtual_batch_size.
PiperOrigin-RevId: 198648273
-rw-r--r--  tensorflow/python/keras/layers/normalization.py | 26 ++++++++++++--------------
1 file changed, 12 insertions(+), 14 deletions(-)
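The change affects layers that enable batch renormalization together with ghost (virtual) batch normalization. A minimal sketch of a configuration that exercises this path, assuming the TF 1.x-era Keras API this file implements (the batch size, feature count, and virtual_batch_size below are illustrative):

```python
import tensorflow as tf

# Both options that interact in this commit: renorm=True turns on batch
# renormalization, virtual_batch_size turns on ghost batch norm (moments are
# computed over sub-batches of 8 examples instead of the full batch of 32).
layer = tf.keras.layers.BatchNormalization(renorm=True, virtual_batch_size=8)

x = tf.random_normal([32, 16])   # full batch of 32 examples, 16 features
y = layer(x, training=True)      # training=True runs the branch changed below
```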
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index c0dc5220f1..7743d00c0f 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -574,28 +574,26 @@ class BatchNormalization(Layer):
                                      lambda: variance,
                                      lambda: moving_variance)
 
+      if self.virtual_batch_size is not None:
+        # This isn't strictly correct since in ghost batch norm, you are
+        # supposed to sequentially update the moving_mean and moving_variance
+        # with each sub-batch. However, since the moving statistics are only
+        # used during evaluation, it is more efficient to just update in one
+        # step and should not make a significant difference in the result.
+        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
+        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
+      else:
+        new_mean, new_variance = mean, variance
+
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            mean, variance, training)
+            new_mean, new_variance, training)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
         r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
         d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
         scale, offset = _compose_transforms(r, d, scale, offset)
-      else:
-        new_mean, new_variance = mean, variance
-
-      if self.virtual_batch_size is not None:
-        # This isn't strictly correct since in ghost batch norm, you are
-        # supposed to sequentially update the moving_mean and moving_variance
-        # with each sub-batch. However, since the moving statistics are only
-        # used during evaluation, it is more efficient to just update in one
-        # step and should not make a significant difference in the result.
-        new_mean = math_ops.reduce_mean(new_mean,
-                                        axis=1, keepdims=True)
-        new_variance = math_ops.reduce_mean(new_variance,
-                                            axis=1, keepdims=True)
 
       def _do_update(var, value):
         if in_eager_mode and not self.trainable:
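As the comment in the added block explains, the per-sub-batch moments are collapsed in a single step rather than updating the moving statistics sequentially per ghost batch. After this reordering, that reduction happens before _renorm_correction_and_moments is called, so the renorm correction sees per-channel moments whose shape matches moving_mean and moving_variance. A rough NumPy sketch of the new ordering; the shapes and the simplified correction below are illustrative assumptions, not the exact TensorFlow code:

```python
import numpy as np

np.random.seed(0)

# Per-ghost-batch moments, shape [1, num_virtual_batches, channels].
mean = np.random.rand(1, 4, 8)
variance = np.random.rand(1, 4, 8) + 0.5

# Collapse the virtual-batch axis first, mirroring
# math_ops.reduce_mean(..., axis=1, keepdims=True) in the new code.
new_mean = mean.mean(axis=1, keepdims=True)          # [1, 1, 8]
new_variance = variance.mean(axis=1, keepdims=True)  # [1, 1, 8]

# Simplified renorm correction against the moving statistics; the real
# version lives in _renorm_correction_and_moments and also clips r and d.
moving_mean = np.zeros((1, 1, 8))
moving_variance = np.ones((1, 1, 8))
r = np.sqrt(new_variance / moving_variance)
d = (new_mean - moving_mean) / np.sqrt(moving_variance)

# r and d are now per-channel, so they broadcast against the moving
# statistics exactly as in the non-virtual-batch case.
print(r.shape, d.shape)  # (1, 1, 8) (1, 1, 8)
```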