| author | Yuefeng Zhou <yuefengz@google.com> | 2018-03-16 15:36:15 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-16 15:41:03 -0700 |
| commit | 8cabc19c6b986d86c3e2b2d8d40f49a9400f926b (patch) | |
| tree | 7b99f8d704f278a253bfbd012951b8fabd43f8a2 /tensorflow/python/layers | |
| parent | 8e974f13f62cb284346f32744dc42c411a520f00 (diff) | |
Consolidate all moving_average updates in batchnorm into one implementation.
PiperOrigin-RevId: 189404070
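The helper consolidated by this change updates a moving statistic as `variable -= (variable - value) * (1 - momentum)`, i.e. the usual exponential moving average `momentum * variable + (1 - momentum) * value`. A minimal NumPy sketch of that arithmetic (illustration only, not the TensorFlow op; variable names are assumed):

```python
import numpy as np

def assign_moving_average(variable, value, momentum):
    """In-place exponential moving average mirroring the update in the diff:
    variable -= (variable - value) * (1 - momentum)."""
    decay = 1.0 - momentum
    variable -= (variable - value) * decay
    return variable

# Example: a moving mean tracking a batch mean with momentum 0.99.
moving_mean = np.zeros(3)
batch_mean = np.array([1.0, 2.0, 3.0])
assign_moving_average(moving_mean, batch_mean, momentum=0.99)
print(moving_mean)  # [0.01 0.02 0.03]
```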
Diffstat (limited to 'tensorflow/python/layers')
| -rw-r--r-- | tensorflow/python/layers/normalization.py | 50 |
1 file changed, 20 insertions, 30 deletions
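In the hunk below that handles a statically unknown `training` value, `utils.smart_cond` selects the momentum and substitutes `1.0` on the non-training branch, so the decay computed inside `_assign_moving_average` is zero and the moving statistics are left untouched. A scalar illustration of why momentum `1.0` makes the update a no-op (plain Python; the toy values are assumptions):

```python
def ema_update(variable, value, momentum):
    # Same arithmetic as the consolidated helper: decay = 1 - momentum.
    return variable - (variable - value) * (1.0 - momentum)

moving_mean, batch_mean = 0.5, 2.0
print(ema_update(moving_mean, batch_mean, momentum=0.99))  # 0.515: training-time update
print(ema_update(moving_mean, batch_mean, momentum=1.0))   # 0.5: no-op when not training
```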
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index c23d755a8e..8b79a92cc4 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -319,7 +319,6 @@ class BatchNormalization(base.Layer):
           initializer=self.moving_variance_initializer,
           trainable=False)
-    self._one_minus_decay = 1.0 - self.momentum
     if self.renorm:
       # Create variables to maintain the moving mean and standard deviation.
       # These are used in training and thus are different from the moving
@@ -360,20 +359,14 @@ class BatchNormalization(base.Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
-  def _assign_moving_average(self, variable, value, one_minus_decay):
+  def _assign_moving_average(self, variable, value, momentum):
     with ops.name_scope(None, 'AssignMovingAvg',
-                        [variable, value, one_minus_decay]) as scope:
+                        [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
+        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
         update_delta = math_ops.multiply(
-            math_ops.subtract(variable.read_value(), value),
-            one_minus_decay)
-        if isinstance(variable, resource_variable_ops.ResourceVariable):
-          # state_ops.assign_sub does an extra read_variable_op after the
-          # assign. We avoid that here.
-          return gen_resource_variable_ops.assign_sub_variable_op(
-              variable.handle, update_delta, name=scope)
-        else:
-          return state_ops.assign_sub(variable, update_delta, name=scope)
+            math_ops.subtract(variable.read_value(), value), decay)
+        return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
@@ -412,22 +405,16 @@ class BatchNormalization(base.Layer):
     training_value = utils.constant_value(training)
     if training_value is None:
-      one_minus_decay = utils.smart_cond(training,
-                                         lambda: self._one_minus_decay,
-                                         lambda: 0.)
+      momentum = utils.smart_cond(training, lambda: self.momentum, lambda: 1.0)
     else:
-      one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
+      momentum = ops.convert_to_tensor(self.momentum)
     if training_value or training_value is None:
       mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                one_minus_decay)
+                                                momentum)
       variance_update = self._assign_moving_average(self.moving_variance,
-                                                    variance, one_minus_decay)
-      if not context.executing_eagerly():
-        # Note that in Eager mode, the updates are already executed when running
-        # assign_moving_averages. So we do not need to put them into
-        # collections.
-        self.add_update(mean_update, inputs=inputs)
-        self.add_update(variance_update, inputs=inputs)
+                                                    variance, momentum)
+      self.add_update(mean_update, inputs=inputs)
+      self.add_update(variance_update, inputs=inputs)
 
     return output
@@ -464,6 +451,7 @@ class BatchNormalization(base.Layer):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
+        """Updates the var and weight, returns their updated ratio."""
         # Update the variables without zero debiasing. The debiasing will be
         # accomplished by dividing the exponential moving average by the weight.
         # For example, after a single update, the moving average would be
         # math_ops.div(value, weight), and the "correct" result will be
         # value.
         # Make sure the weight is not updated until before r and d computation.
         with ops.control_dependencies([value]):
           weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = moving_averages.assign_moving_average(
-            var, value, self.renorm_momentum, zero_debias=False)
-        new_weight = moving_averages.assign_moving_average(
-            weight, weight_value, self.renorm_momentum, zero_debias=False)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+        new_weight = self._assign_moving_average(weight, weight_value,
+                                                 self.renorm_momentum)
+        # TODO(yuefengz): the updates to var and weighted can not be batched
+        # together if we fetch their updated values here. Consider calculating
+        # new values and delaying the updates.
         return new_var / new_weight
+
       def _fake_update():
         return array_ops.identity(var)
       return utils.smart_cond(training, _do_update, _fake_update)
@@ -601,8 +592,7 @@ class BatchNormalization(base.Layer):
         if in_eager_mode and not self.trainable:
           return
 
-        return moving_averages.assign_moving_average(
-            var, value, self.momentum, zero_debias=False)
+        return self._assign_moving_average(var, value, self.momentum)
 
       mean_update = utils.smart_cond(
           training,
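The renorm path keeps the trick the surrounding comments describe: both the moving average and a scalar weight (an exponential moving average of the constant 1.0) are updated without zero debiasing, and dividing one by the other yields the unbiased value. A standalone sketch of that bookkeeping (assumed toy values, not the library code):

```python
def debiased_moving_average(values, momentum):
    """Raw EMA of the values plus an EMA of the constant 1.0 (the 'weight');
    their ratio is the zero-debiased estimate."""
    var, weight = 0.0, 0.0
    for value in values:
        var -= (var - value) * (1.0 - momentum)
        weight -= (weight - 1.0) * (1.0 - momentum)
        yield var / weight

# After the first update the biased average is (1 - momentum) * value and the
# weight is (1 - momentum), so their ratio recovers the value exactly.
print(list(debiased_moving_average([4.0, 2.0], momentum=0.9)))  # [~4.0, ~2.95]
```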