diff options
author    | 2017-04-11 17:04:44 -0800
---|---
committer | 2017-04-11 18:31:42 -0700
commit | 348b3b114021c06d90dcf0f0efd3f22642d64929 (patch) | |
tree | 2b8267e193c3388781c36fc0be8b2eafc4dcda11 | |
parent | 0aa67b90001f48c1004754036505b7e405243a28 (diff) |
Ensure that moving_mean and moving_variance are grabbed before being updated. Note that this should not be necessary since they are only updated during training and used during testing (unlike renorm_* which are both used and updated during training). However, due to an apparent bug in tf.cond, we create "fake" updates during inference (with decay=1), and so data races may result. This CL fixes that.
Change: 152887553
-rw-r--r-- | tensorflow/python/layers/normalization.py | 14
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 5fd20259be..34b663119e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -282,6 +282,13 @@ class BatchNormalization(base._Layer):  # pylint: disable=protected-access
       # Some of the computations here are not necessary when training==False
       # but not a constant. However, this makes the code simpler.
       mean, variance = nn.moments(inputs, reduction_axes)
+      mean = _smart_select(training,
+                           lambda: mean,
+                           lambda: self.moving_mean)
+      variance = _smart_select(training,
+                               lambda: variance,
+                               lambda: self.moving_variance)
+
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
             mean, variance, training)
@@ -312,13 +319,6 @@ class BatchNormalization(base._Layer):  # pylint: disable=protected-access
         self.updates.append(mean_update)
         self.updates.append(variance_update)
 
-      mean = _smart_select(training,
-                           lambda: mean,
-                           lambda: self.moving_mean)
-      variance = _smart_select(training,
-                               lambda: variance,
-                               lambda: self.moving_variance)
-
     else:
       mean, variance = self.moving_mean, self.moving_variance