diff options
Diffstat (limited to 'tensorflow/python/estimator/keras.py')
-rw-r--r-- | tensorflow/python/estimator/keras.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index a5f07fea3b..e4ce5339d0 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -361,7 +361,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): """model_fn for keras Estimator.""" # Raise an error when users use DistributionStrategy with native Keras # optimizers. Currently we only support native TensorFlow optimizers. - if distribute_lib.has_distribution_strategy() and \ + if distribution_strategy_context.has_distribution_strategy() and \ not isinstance(keras_model.optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): raise ValueError('Only TensorFlow native optimizers are supported with ' @@ -373,7 +373,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): # We need to make sure that the output names of the last layer in the model # is the same for each of the cloned models. This is required for mirrored # strategy when we call regroup. - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): for name in model.output_names: name = re.compile(r'_\d$').sub('', name) model_output_names.append(name) |