aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/keras.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/keras.py')
-rw-r--r--tensorflow/python/estimator/keras.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index a5f07fea3b..e4ce5339d0 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -43,7 +43,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -361,7 +361,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
"""model_fn for keras Estimator."""
# Raise an error when users use DistributionStrategy with native Keras
# optimizers. Currently we only support native TensorFlow optimizers.
- if distribute_lib.has_distribution_strategy() and \
+ if distribution_strategy_context.has_distribution_strategy() and \
not isinstance(keras_model.optimizer,
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TensorFlow native optimizers are supported with '
@@ -373,7 +373,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
# We need to make sure that the output names of the last layer in the model
# is the same for each of the cloned models. This is required for mirrored
# strategy when we call regroup.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
for name in model.output_names:
name = re.compile(r'_\d$').sub('', name)
model_output_names.append(name)