| author | Chris Ying <chrisying@google.com> | 2018-05-30 17:38:13 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-30 17:40:26 -0700 |
| commit | 2a484497062677f5cf0205ee3b9c28a64f03fe04 (patch) | |
| tree | e9c4089ffdac64eb993106b7bb46988d495913ad | |
| parent | 49535c9da686ea24f4e755e90fdaaa97f9f91b9d (diff) | |
Fix bug with renorm + virtual_batch_size.
PiperOrigin-RevId: 198648273
| -rw-r--r-- | tensorflow/python/keras/layers/normalization.py | 26 |

1 file changed, 12 insertions(+), 14 deletions(-)
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index c0dc5220f1..7743d00c0f 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -574,28 +574,26 @@ class BatchNormalization(Layer):
                                      lambda: variance,
                                      lambda: moving_variance)
 
+      if self.virtual_batch_size is not None:
+        # This isn't strictly correct since in ghost batch norm, you are
+        # supposed to sequentially update the moving_mean and moving_variance
+        # with each sub-batch. However, since the moving statistics are only
+        # used during evaluation, it is more efficient to just update in one
+        # step and should not make a significant difference in the result.
+        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
+        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
+      else:
+        new_mean, new_variance = mean, variance
+
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            mean, variance, training)
+            new_mean, new_variance, training)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
         r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
         d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
         scale, offset = _compose_transforms(r, d, scale, offset)
-      else:
-        new_mean, new_variance = mean, variance
-
-      if self.virtual_batch_size is not None:
-        # This isn't strictly correct since in ghost batch norm, you are
-        # supposed to sequentially update the moving_mean and moving_variance
-        # with each sub-batch. However, since the moving statistics are only
-        # used during evaluation, it is more efficient to just update in one
-        # step and should not make a significant difference in the result.
-        new_mean = math_ops.reduce_mean(new_mean,
-                                        axis=1, keepdims=True)
-        new_variance = math_ops.reduce_mean(new_variance,
-                                            axis=1, keepdims=True)
 
       def _do_update(var, value):
         if in_eager_mode and not self.trainable:
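To make the reordering concrete: before this change, the renorm branch passed the per-ghost-batch `mean` and `variance` straight into `_renorm_correction_and_moments`, and the reduction over the virtual-batch axis only happened afterwards; the fix collapses the ghost-batch axis first, so the renorm correction sees the same per-channel moments that the moving averages track. Below is a minimal sketch of that ordering, not TensorFlow's actual implementation: NumPy stands in for `math_ops`, `renorm_correction` is a simplified stand-in for `_renorm_correction_and_moments` (clipping and the moving-average update are omitted), and the shapes are hypothetical.

```python
# A minimal sketch of the ordering fix, not TensorFlow's actual code.
import numpy as np

def renorm_correction(mean, variance, moving_mean, moving_variance, eps=1e-3):
  # Batch-renorm correction factors r and d: with them, the normalized value
  # x is transformed as (x * r + d) * gamma + beta.
  sigma = np.sqrt(variance + eps)
  moving_sigma = np.sqrt(moving_variance + eps)
  r = sigma / moving_sigma
  d = (mean - moving_mean) / moving_sigma
  return r, d

# Hypothetical shapes: axis 1 indexes 2 ghost batches, axis 2 indexes 4
# channels. Per-ghost-batch moments keep the ghost-batch axis, while the
# moving statistics are tracked once per channel.
mean = np.random.rand(1, 2, 4)
variance = np.random.rand(1, 2, 4)
moving_mean = np.zeros((1, 1, 4))
moving_variance = np.ones((1, 1, 4))

# After the fix: collapse the ghost-batch axis *before* the renorm branch,
# so the correction compares per-channel moments against per-channel moving
# statistics.
new_mean = mean.mean(axis=1, keepdims=True)          # shape (1, 1, 4)
new_variance = variance.mean(axis=1, keepdims=True)  # shape (1, 1, 4)
r, d = renorm_correction(new_mean, new_variance, moving_mean, moving_variance)
assert r.shape == d.shape == moving_mean.shape

# Before the fix, the renorm branch received the unreduced (1, 2, 4) moments
# and the ghost-batch axis was only averaged out afterwards, so the correction
# was computed against tensors whose shape did not match the moving statistics.
```

The in-diff comment makes the remaining trade-off explicit: sequentially updating the moving statistics with each sub-batch would be strictly correct ghost batch norm, but a single update from the averaged moments is cheaper and only affects evaluation-time statistics.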