diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-03-29 13:22:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-29 13:25:02 -0700 |
commit | eb2be37c12ae2b6c996f3f4c064e3d10f9565eab (patch) | |
tree | 0cfdd4f1654b66202754601b510b0a28304a1d2c /tensorflow/python/layers | |
parent | a259ba951d3af9f62a0f95a881abf9ebaa45782b (diff) |
Internal change.
PiperOrigin-RevId: 190976338
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/normalization.py | 76 |
1 files changed, 36 insertions, 40 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 29fb92ccb5..83b201e642 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -32,12 +32,12 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import init_ops from tensorflow.python.ops import state_ops +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import moving_averages from tensorflow.python.util.tf_export import tf_export @@ -178,6 +178,11 @@ class BatchNormalization(base.Layer): self.renorm_clipping = renorm_clipping self.renorm_momentum = renorm_momentum + def _add_tower_local_variable(self, *args, **kwargs): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope('mean'): + return self.add_variable(*args, **kwargs) + def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if not input_shape.ndims: @@ -305,14 +310,14 @@ class BatchNormalization(base.Layer): self._scope.set_partitioner(None) else: partitioner = None - self.moving_mean = self.add_variable( + self.moving_mean = self._add_tower_local_variable( name='moving_mean', shape=param_shape, dtype=param_dtype, initializer=self.moving_mean_initializer, trainable=False) - self.moving_variance = self.add_variable( + self.moving_variance = self._add_tower_local_variable( name='moving_variance', shape=param_shape, dtype=param_dtype, @@ -328,7 +333,7 @@ class BatchNormalization(base.Layer): # stack to be cleared. The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. def _renorm_variable(name, shape): - var = self.add_variable( + var = self._add_tower_local_variable( name=name, shape=shape, dtype=param_dtype, @@ -336,24 +341,19 @@ class BatchNormalization(base.Layer): trainable=False) return var - with ops.device(None): - device = ( - self.moving_mean.device if context.executing_eagerly() else - (lambda _: self.moving_mean.device)) - with ops.device(device): - self.renorm_mean = _renorm_variable('renorm_mean', param_shape) - self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) - # We initialize renorm_stddev to 0, and maintain the (0-initialized) - # renorm_stddev_weight. This allows us to (1) mix the average - # stddev with the minibatch stddev early in training, and (2) compute - # the unbiased average stddev by dividing renorm_stddev by the weight. - device = ( - self.moving_variance.device if context.executing_eagerly() else - (lambda _: self.moving_variance.device)) - with ops.device(device): - self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) - self.renorm_stddev_weight = _renorm_variable( - 'renorm_stddev_weight', ()) + with distribute_lib.get_distribution_strategy().colocate_vars_with( + self.moving_mean): + self.renorm_mean = _renorm_variable('renorm_mean', param_shape) + self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) + # We initialize renorm_stddev to 0, and maintain the (0-initialized) + # renorm_stddev_weight. This allows us to (1) mix the average + # stddev with the minibatch stddev early in training, and (2) compute + # the unbiased average stddev by dividing renorm_stddev by the weight. + with distribute_lib.get_distribution_strategy().colocate_vars_with( + self.moving_variance): + self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) + self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight', + ()) finally: if partitioner: self._scope.set_partitioner(partitioner) @@ -362,12 +362,11 @@ class BatchNormalization(base.Layer): def _assign_moving_average(self, variable, value, momentum): with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope: - with ops.colocate_with(variable): - decay = ops.convert_to_tensor(1.0 - momentum, name='decay') - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - value) * decay - return state_ops.assign_sub(variable, update_delta, name=scope) + decay = ops.convert_to_tensor(1.0 - momentum, name='decay') + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - value) * decay + return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): """Returns the output of fused batch norm.""" @@ -473,16 +472,13 @@ class BatchNormalization(base.Layer): return array_ops.identity(var) return utils.smart_cond(training, _do_update, _fake_update) - with ops.colocate_with(self.moving_mean): - new_mean = _update_renorm_variable(self.renorm_mean, - self.renorm_mean_weight, - mean) - with ops.colocate_with(self.moving_variance): - new_stddev = _update_renorm_variable(self.renorm_stddev, - self.renorm_stddev_weight, - stddev) - # Make sqrt(moving_variance + epsilon) = new_stddev. - new_variance = math_ops.square(new_stddev) - self.epsilon + # TODO(yuefengz): colocate the operations + new_mean = _update_renorm_variable(self.renorm_mean, + self.renorm_mean_weight, mean) + new_stddev = _update_renorm_variable(self.renorm_stddev, + self.renorm_stddev_weight, stddev) + # Make sqrt(moving_variance + epsilon) = new_stddev. + new_variance = math_ops.square(new_stddev) - self.epsilon return (r, d, new_mean, new_variance) |