author    Yuefeng Zhou <yuefengz@google.com>    2018-03-16 15:36:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-03-16 15:41:03 -0700
commit    8cabc19c6b986d86c3e2b2d8d40f49a9400f926b (patch)
tree      7b99f8d704f278a253bfbd012951b8fabd43f8a2 /tensorflow/python/layers
parent    8e974f13f62cb284346f32744dc42c411a520f00 (diff)
Consolidate all moving_average updates in batchnorm into one implementation.
PiperOrigin-RevId: 189404070
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--  tensorflow/python/layers/normalization.py  50
1 file changed, 20 insertions(+), 30 deletions(-)
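
The change routes every moving-average update in BatchNormalization through a single exponential-moving-average rule: variable -= (variable - value) * (1 - momentum), which is algebraically the same as variable = momentum * variable + (1 - momentum) * value. A minimal plain-Python sketch of that rule (illustrative names only, not code from this commit):

def ema_update(variable, value, momentum):
    # variable <- variable - (variable - value) * (1 - momentum)
    # equivalently: momentum * variable + (1 - momentum) * value
    return variable - (variable - value) * (1.0 - momentum)

assert abs(ema_update(0.0, 1.0, momentum=0.99) - 0.01) < 1e-9
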
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index c23d755a8e..8b79a92cc4 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -319,7 +319,6 @@ class BatchNormalization(base.Layer):
initializer=self.moving_variance_initializer,
trainable=False)
- self._one_minus_decay = 1.0 - self.momentum
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
# These are used in training and thus are different from the moving
@@ -360,20 +359,14 @@ class BatchNormalization(base.Layer):
self._scope.set_partitioner(partitioner)
self.built = True
- def _assign_moving_average(self, variable, value, one_minus_decay):
+ def _assign_moving_average(self, variable, value, momentum):
with ops.name_scope(None, 'AssignMovingAvg',
- [variable, value, one_minus_decay]) as scope:
+ [variable, value, momentum]) as scope:
with ops.colocate_with(variable):
+ decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
update_delta = math_ops.multiply(
- math_ops.subtract(variable.read_value(), value),
- one_minus_decay)
- if isinstance(variable, resource_variable_ops.ResourceVariable):
- # state_ops.assign_sub does an extra read_variable_op after the
- # assign. We avoid that here.
- return gen_resource_variable_ops.assign_sub_variable_op(
- variable.handle, update_delta, name=scope)
- else:
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ math_ops.subtract(variable.read_value(), value), decay)
+ return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
@@ -412,22 +405,16 @@ class BatchNormalization(base.Layer):
training_value = utils.constant_value(training)
if training_value is None:
- one_minus_decay = utils.smart_cond(training,
- lambda: self._one_minus_decay,
- lambda: 0.)
+ momentum = utils.smart_cond(training, lambda: self.momentum, lambda: 1.0)
else:
- one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
+ momentum = ops.convert_to_tensor(self.momentum)
if training_value or training_value is None:
mean_update = self._assign_moving_average(self.moving_mean, mean,
- one_minus_decay)
+ momentum)
variance_update = self._assign_moving_average(self.moving_variance,
- variance, one_minus_decay)
- if not context.executing_eagerly():
- # Note that in Eager mode, the updates are already executed when running
- # assign_moving_averages. So we do not need to put them into
- # collections.
- self.add_update(mean_update, inputs=inputs)
- self.add_update(variance_update, inputs=inputs)
+ variance, momentum)
+ self.add_update(mean_update, inputs=inputs)
+ self.add_update(variance_update, inputs=inputs)
return output
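
When the training flag is not statically known, the new code picks the momentum with smart_cond: self.momentum while training and 1.0 otherwise, so the decay factor (1 - momentum) is zero and the moving statistics stay untouched at inference. That is the same effect the removed code achieved by selecting one_minus_decay = 0. A small sketch of that effect (illustrative numbers, not from the commit):

def moving_delta(variable, value, momentum):
    # Amount subtracted from the moving statistic by _assign_moving_average.
    return (variable - value) * (1.0 - momentum)

print(moving_delta(0.5, 2.0, momentum=1.0))   # -0.0: statistic left unchanged
print(moving_delta(0.5, 2.0, momentum=0.99))  # about -0.015: moves toward the batch value
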
@@ -464,6 +451,7 @@ class BatchNormalization(base.Layer):
"""Updates a moving average and weight, returns the unbiased value."""
value = array_ops.identity(value)
def _do_update():
+ """Updates the var and weight, returns their updated ratio."""
# Update the variables without zero debiasing. The debiasing will be
# accomplished by dividing the exponential moving average by the weight.
# For example, after a single update, the moving average would be
@@ -472,11 +460,14 @@ class BatchNormalization(base.Layer):
# Make sure the weight is not updated until before r and d computation.
with ops.control_dependencies([value]):
weight_value = array_ops.constant(1., dtype=weight.dtype)
- new_var = moving_averages.assign_moving_average(
- var, value, self.renorm_momentum, zero_debias=False)
- new_weight = moving_averages.assign_moving_average(
- weight, weight_value, self.renorm_momentum, zero_debias=False)
+ new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+ new_weight = self._assign_moving_average(weight, weight_value,
+ self.renorm_momentum)
+ # TODO(yuefengz): the updates to var and weight can not be batched
+ # together if we fetch their updated values here. Consider calculating
+ # new values and delaying the updates.
return new_var / new_weight
+
def _fake_update():
return array_ops.identity(var)
return utils.smart_cond(training, _do_update, _fake_update)
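
In the renorm path, var and weight are now updated through the same helper without zero-debiasing, and the debiased estimate is recovered as new_var / new_weight: both start at zero and decay with the same renorm_momentum, so the bias from the zero initializer cancels in the ratio. A plain-Python illustration of the first update (hypothetical values):

def ema_update(current, value, momentum):
    return current - (current - value) * (1.0 - momentum)

var, weight = 0.0, 0.0
var = ema_update(var, 3.0, momentum=0.99)        # about 0.03, biased toward zero
weight = ema_update(weight, 1.0, momentum=0.99)  # about 0.01, same bias
print(var / weight)                              # about 3.0: the debiased estimate
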
@@ -601,8 +592,7 @@ class BatchNormalization(base.Layer):
if in_eager_mode and not self.trainable:
return
- return moving_averages.assign_moving_average(
- var, value, self.momentum, zero_debias=False)
+ return self._assign_moving_average(var, value, self.momentum)
mean_update = utils.smart_cond(
training,
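
The final hunk replaces the remaining direct call to moving_averages.assign_moving_average with the shared helper; the resulting update is still gated on training via utils.smart_cond. A rough sketch of the selection semantics assumed here (a simplification: the real utility also accepts tensor predicates and then emits a conditional op):

def smart_cond_sketch(pred, true_fn, false_fn):
    # With a static Python bool predicate, only the chosen branch runs.
    if isinstance(pred, bool):
        return true_fn() if pred else false_fn()
    raise NotImplementedError('a tensor predicate would require tf.cond')

print(smart_cond_sketch(True, lambda: 'assign moving mean', lambda: 'identity'))   # training
print(smart_cond_sketch(False, lambda: 'assign moving mean', lambda: 'identity'))  # inference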