Diffstat (limited to 'tensorflow/python/keras/layers/normalization.py')
 tensorflow/python/keras/layers/normalization.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index a7835bc0a2..cd26e04c39 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -36,7 +36,7 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util.tf_export import tf_export
@@ -345,16 +345,16 @@ class BatchNormalization(Layer):
             aggregation=variable_scope.VariableAggregation.MEAN)
         return var
 
-      with distribute_lib.get_distribution_strategy().colocate_vars_with(
-          self.moving_mean):
+      with distribution_strategy_context.get_distribution_strategy(
+      ).colocate_vars_with(self.moving_mean):
         self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
         self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
       # We initialize renorm_stddev to 0, and maintain the (0-initialized)
       # renorm_stddev_weight. This allows us to (1) mix the average
       # stddev with the minibatch stddev early in training, and (2) compute
       # the unbiased average stddev by dividing renorm_stddev by the weight.
-      with distribute_lib.get_distribution_strategy().colocate_vars_with(
-          self.moving_variance):
+      with distribution_strategy_context.get_distribution_strategy(
+      ).colocate_vars_with(self.moving_variance):
         self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
         self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                      ())
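
For orientation, a minimal sketch of the call pattern this patch migrates to, assuming the TF 1.x internal module layout at this commit. The helper and variable names below are illustrative, not part of the patch:

    from tensorflow.python.training import distribution_strategy_context

    def make_colocated_variable(make_var, anchor_var):
      # get_distribution_strategy() now lives in distribution_strategy_context;
      # before this patch it was reached via training.distribute (imported as
      # distribute_lib).
      strategy = distribution_strategy_context.get_distribution_strategy()
      # colocate_vars_with(...) makes variables created inside the scope land
      # on the same device(s) as anchor_var (here: a moving mean or variance).
      with strategy.colocate_vars_with(anchor_var):
        return make_var()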
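The comment in the second hunk describes a zero-init bias correction for the renorm moving stddev. A hedged illustration of that trick in plain Python (function name and decay value invented here, not from the patch):

    def update_renorm_stddev(stddev, weight, batch_stddev, decay=0.99):
      # Both accumulators start at zero. Each step folds the minibatch stddev
      # into the running value and folds 1.0 into the running weight.
      stddev = decay * stddev + (1 - decay) * batch_stddev
      weight = decay * weight + (1 - decay) * 1.0
      # Dividing by the weight cancels the zero initialization: early in
      # training weight < 1, so stddev / weight is an unbiased average even
      # though few updates have been applied.
      return stddev, weight, stddev / weight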