diff options
Diffstat (limited to 'tensorflow/python/keras/layers/normalization.py')
-rw-r--r-- | tensorflow/python/keras/layers/normalization.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index a7835bc0a2..cd26e04c39 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.tf_export import tf_export @@ -345,16 +345,16 @@ class BatchNormalization(Layer): aggregation=variable_scope.VariableAggregation.MEAN) return var - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_mean): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_mean): self.renorm_mean = _renorm_variable('renorm_mean', param_shape) self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) # We initialize renorm_stddev to 0, and maintain the (0-initialized) # renorm_stddev_weight. This allows us to (1) mix the average # stddev with the minibatch stddev early in training, and (2) compute # the unbiased average stddev by dividing renorm_stddev by the weight. - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_variance): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_variance): self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight', ()) |