about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Gravatar A. Unique TensorFlower <gardener@tensorflow.org> 2017-04-11 17:04:44 -0800
committer Gravatar TensorFlower Gardener <gardener@tensorflow.org> 2017-04-11 18:31:42 -0700
commit 348b3b114021c06d90dcf0f0efd3f22642d64929 (patch)
tree 2b8267e193c3388781c36fc0be8b2eafc4dcda11
parent 0aa67b90001f48c1004754036505b7e405243a28 (diff)
Ensure that moving_mean and moving_variance are grabbed before being updated. Note that this should not be necessary since they are only updated during training and used during testing (unlike renorm_* which are both used and updated during training). However, due to an apparent bug in tf.cond, we create "fake" updates during inference (with decay=1), and so data races may result. This CL fixes that.
Change: 152887553
-rw-r--r-- tensorflow/python/layers/normalization.py 14
1 file changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 5fd20259be..34b663119e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -282,6 +282,13 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
# Some of the computations here are not necessary when training==False
# but not a constant. However, this makes the code simpler.
mean, variance = nn.moments(inputs, reduction_axes)
+ mean = _smart_select(training,
+ lambda: mean,
+ lambda: self.moving_mean)
+ variance = _smart_select(training,
+ lambda: variance,
+ lambda: self.moving_variance)
+
if self.renorm:
r, d, new_mean, new_variance = self._renorm_correction_and_moments(
mean, variance, training)
@@ -312,13 +319,6 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access
self.updates.append(mean_update)
self.updates.append(variance_update)
- mean = _smart_select(training,
- lambda: mean,
- lambda: self.moving_mean)
- variance = _smart_select(training,
- lambda: variance,
- lambda: self.moving_variance)
-
else:
mean, variance = self.moving_mean, self.moving_variance