author    Yuefeng Zhou <yuefengz@google.com>  2018-03-29 13:22:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-29 13:25:02 -0700
commit    eb2be37c12ae2b6c996f3f4c064e3d10f9565eab
tree      0cfdd4f1654b66202754601b510b0a28304a1d2c /tensorflow/python/layers
parent    a259ba951d3af9f62a0f95a881abf9ebaa45782b
Internal change.
PiperOrigin-RevId: 190976338
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--  tensorflow/python/layers/normalization.py  76
1 file changed, 36 insertions(+), 40 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 29fb92ccb5..83b201e642 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -32,12 +32,12 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import base
 from tensorflow.python.layers import utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import moving_averages
 from tensorflow.python.util.tf_export import tf_export
@@ -178,6 +178,11 @@ class BatchNormalization(base.Layer):
     self.renorm_clipping = renorm_clipping
     self.renorm_momentum = renorm_momentum

+  def _add_tower_local_variable(self, *args, **kwargs):
+    tower_context = distribute_lib.get_tower_context()
+    with tower_context.tower_local_var_scope('mean'):
+      return self.add_variable(*args, **kwargs)
+
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if not input_shape.ndims:
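The helper added in this hunk routes creation of the moving statistics through tower_local_var_scope('mean'), so that under a DistributionStrategy each tower keeps its own copy of the statistic while cross-tower reads aggregate the copies. A plain-Python sketch of that intended behaviour (an editor's illustration, not TensorFlow code; the mean aggregation is inferred from the 'mean' argument to the scope):

    import numpy as np

    class TowerLocalMeanVariable(object):
      """Toy stand-in for a tower-local variable with 'mean' aggregation."""

      def __init__(self, num_towers, shape):
        # One independent copy per tower; batch norm would update each copy
        # with the statistics computed on that tower's slice of the batch.
        self.copies = [np.zeros(shape) for _ in range(num_towers)]

      def assign(self, tower_id, value):
        self.copies[tower_id] = np.asarray(value, dtype=float)

      def read(self):
        # A cross-tower read (e.g. for checkpointing or inference) returns
        # the mean of the per-tower copies.
        return np.mean(self.copies, axis=0)

When no strategy is active, the scope should reduce to an ordinary add_variable call.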
@@ -305,14 +310,14 @@ class BatchNormalization(base.Layer):
         self._scope.set_partitioner(None)
       else:
         partitioner = None
-      self.moving_mean = self.add_variable(
+      self.moving_mean = self._add_tower_local_variable(
           name='moving_mean',
           shape=param_shape,
           dtype=param_dtype,
           initializer=self.moving_mean_initializer,
           trainable=False)
-      self.moving_variance = self.add_variable(
+      self.moving_variance = self._add_tower_local_variable(
           name='moving_variance',
           shape=param_shape,
           dtype=param_dtype,
@@ -328,7 +333,7 @@ class BatchNormalization(base.Layer):
       # stack to be cleared. The nested ones use a `lambda` to set the desired
       # device and ignore any devices that may be set by the custom getter.
       def _renorm_variable(name, shape):
-        var = self.add_variable(
+        var = self._add_tower_local_variable(
             name=name,
             shape=shape,
             dtype=param_dtype,
@@ -336,24 +341,19 @@ class BatchNormalization(base.Layer):
             trainable=False)
         return var

-      with ops.device(None):
-        device = (
-            self.moving_mean.device if context.executing_eagerly() else
-            (lambda _: self.moving_mean.device))
-        with ops.device(device):
-          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
-          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
-        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
-        # renorm_stddev_weight. This allows us to (1) mix the average
-        # stddev with the minibatch stddev early in training, and (2) compute
-        # the unbiased average stddev by dividing renorm_stddev by the weight.
-        device = (
-            self.moving_variance.device if context.executing_eagerly() else
-            (lambda _: self.moving_variance.device))
-        with ops.device(device):
-          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
-          self.renorm_stddev_weight = _renorm_variable(
-              'renorm_stddev_weight', ())
+      with distribute_lib.get_distribution_strategy().colocate_vars_with(
+          self.moving_mean):
+        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
+        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
+      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
+      # renorm_stddev_weight. This allows us to (1) mix the average
+      # stddev with the minibatch stddev early in training, and (2) compute
+      # the unbiased average stddev by dividing renorm_stddev by the weight.
+      with distribute_lib.get_distribution_strategy().colocate_vars_with(
+          self.moving_variance):
+        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
+        self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
+                                                     ())
     finally:
       if partitioner:
         self._scope.set_partitioner(partitioner)
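The comment carried over above describes a zero-debiased moving average: renorm_stddev and renorm_stddev_weight both start at 0, and dividing the accumulated stddev by the accumulated weight gives an unbiased estimate of the average stddev even in the first training steps. A minimal plain-Python sketch of that arithmetic (illustration only; the function name and momentum argument are placeholders, not the layer's API):

    def update_debiased_average(running_value, running_weight, sample, momentum):
      # Both accumulators start at 0, so the raw moving average is biased
      # toward 0 early on; dividing by the accumulated weight removes the bias.
      running_value = momentum * running_value + (1.0 - momentum) * sample
      running_weight = momentum * running_weight + (1.0 - momentum) * 1.0
      return running_value, running_weight, running_value / running_weight

    # After a single update from 0 with momentum 0.99:
    #   running_value  = 0.01 * sample
    #   running_weight = 0.01
    #   debiased value = sample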
@@ -362,12 +362,11 @@ class BatchNormalization(base.Layer):
   def _assign_moving_average(self, variable, value, momentum):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
-      with ops.colocate_with(variable):
-        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
-        if decay.dtype != variable.dtype.base_dtype:
-          decay = math_ops.cast(decay, variable.dtype.base_dtype)
-        update_delta = (variable - value) * decay
-        return state_ops.assign_sub(variable, update_delta, name=scope)
+      decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
+      if decay.dtype != variable.dtype.base_dtype:
+        decay = math_ops.cast(decay, variable.dtype.base_dtype)
+      update_delta = (variable - value) * decay
+      return state_ops.assign_sub(variable, update_delta, name=scope)

   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
@@ -473,16 +472,13 @@ class BatchNormalization(base.Layer):
           return array_ops.identity(var)
       return utils.smart_cond(training, _do_update, _fake_update)

-    with ops.colocate_with(self.moving_mean):
-      new_mean = _update_renorm_variable(self.renorm_mean,
-                                         self.renorm_mean_weight,
-                                         mean)
-    with ops.colocate_with(self.moving_variance):
-      new_stddev = _update_renorm_variable(self.renorm_stddev,
-                                           self.renorm_stddev_weight,
-                                           stddev)
-      # Make sqrt(moving_variance + epsilon) = new_stddev.
-      new_variance = math_ops.square(new_stddev) - self.epsilon
+    # TODO(yuefengz): colocate the operations
+    new_mean = _update_renorm_variable(self.renorm_mean,
+                                       self.renorm_mean_weight, mean)
+    new_stddev = _update_renorm_variable(self.renorm_stddev,
+                                         self.renorm_stddev_weight, stddev)
+    # Make sqrt(moving_variance + epsilon) = new_stddev.
+    new_variance = math_ops.square(new_stddev) - self.epsilon
     return (r, d, new_mean, new_variance)
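The last hunk keeps the relationship stated in the comment: batch normalization divides by sqrt(moving_variance + epsilon), so the variance consistent with the renormalized stddev is the square of new_stddev minus epsilon. A one-function plain-Python sketch of that algebra (illustration only):

    def variance_for_target_stddev(new_stddev, epsilon):
      # Chosen so that sqrt(new_variance + epsilon) == new_stddev.
      return new_stddev ** 2 - epsilon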